From 4bb49d4e00a2fe6ecfb644c424dc8d88edc02590 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Mon, 30 Sep 2024 02:30:29 -0700 Subject: [PATCH] =?UTF-8?q?Enable=20non-safetensor=20ser/deser=20for=20Tor?= =?UTF-8?q?chAoConfig=20quantized=20model=20=F0=9F=94=B4=20=20(#33456)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Enable non-safetensor serialization and deserialization for TorchAoConfig quantized model Summary: After https://github.com/huggingface/huggingface_hub/pull/2440 we added non-safetensor serialization and deserialization in huggingface, with this we can now add the support in transformers Note that we don't plan to add safetensor serialization due to different goals of wrapper tensor subclass and safetensor see README for more details Test Plan: tested locally Reviewers: Subscribers: Tasks: Tags: * formatting * formatting * minor fix * formatting * address comments * comments * minor fix * update doc * refactor compressed tensor quantizer --- docs/source/en/quantization/torchao.md | 50 +++++++++++++++++- src/transformers/modeling_utils.py | 39 ++++++++++---- src/transformers/quantizers/base.py | 3 +- src/transformers/quantizers/quantizer_aqlm.py | 3 +- src/transformers/quantizers/quantizer_awq.py | 3 +- .../quantizers/quantizer_bnb_4bit.py | 5 +- .../quantizers/quantizer_bnb_8bit.py | 7 ++- .../quantizer_compressed_tensors.py | 3 +- src/transformers/quantizers/quantizer_eetq.py | 3 +- .../quantizers/quantizer_fbgemm_fp8.py | 3 +- src/transformers/quantizers/quantizer_gptq.py | 3 +- src/transformers/quantizers/quantizer_hqq.py | 5 +- .../quantizers/quantizer_quanto.py | 3 +- .../quantizers/quantizer_torchao.py | 16 ++++-- src/transformers/utils/quantization_config.py | 52 +++++++++++-------- 15 files changed, 134 insertions(+), 64 deletions(-) diff --git a/docs/source/en/quantization/torchao.md b/docs/source/en/quantization/torchao.md index 99ad60a923..cd1d0188c3 100644 --- a/docs/source/en/quantization/torchao.md +++ b/docs/source/en/quantization/torchao.md @@ -11,7 +11,7 @@ rendered properly in your Markdown viewer. # TorchAO -[TorchAO](https://github.com/pytorch/ao) is an architecture optimization library for PyTorch, it provides high performance dtypes, optimization techniques and kernels for inference and training, featuring composability with native PyTorch features like `torch.compile`, FSDP etc.. Some benchmark numbers can be found [here](https://github.com/pytorch/ao/tree/main?tab=readme-ov-file#without-intrusive-code-changes) +[TorchAO](https://github.com/pytorch/ao) is an architecture optimization library for PyTorch, it provides high performance dtypes, optimization techniques and kernels for inference and training, featuring composability with native PyTorch features like `torch.compile`, FSDP etc.. Some benchmark numbers can be found [here](https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks). Before you begin, make sure the following libraries are installed with their latest version: @@ -21,6 +21,7 @@ pip install --upgrade torch torchao ```py +import torch from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer model_name = "meta-llama/Meta-Llama-3-8B" @@ -40,6 +41,51 @@ quantized_model = torch.compile(quantized_model, mode="max-autotune") output = quantized_model.generate(**input_ids, max_new_tokens=10) print(tokenizer.decode(output[0], skip_special_tokens=True)) + +# benchmark the performance +import torch.utils.benchmark as benchmark + +def benchmark_fn(f, *args, **kwargs): + # Manual warmup + for _ in range(5): + f(*args, **kwargs) + + t0 = benchmark.Timer( + stmt="f(*args, **kwargs)", + globals={"args": args, "kwargs": kwargs, "f": f}, + num_threads=torch.get_num_threads(), + ) + return f"{(t0.blocked_autorange().mean):.3f}" + +MAX_NEW_TOKENS = 1000 +print("int4wo-128 model:", benchmark_fn(quantized_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS)) + +bf16_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cuda", torch_dtype=torch.bfloat16) +bf16_model = torch.compile(bf16_model, mode="max-autotune") +print("bf16 model:", benchmark_fn(bf16_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS)) + ``` -torchao quantization is implemented with tensor subclasses, currently it does not work with huggingface serialization, both the safetensor option and [non-safetensor option](https://github.com/huggingface/transformers/issues/32364), we'll update here with instructions when it's working. +## Serialization and Deserialization +torchao quantization is implemented with [tensor subclasses](https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor), it only work with huggingface non-safetensor serialization and deserialization. It relies on `torch.load(..., weights_only=True)` to avoid arbitrary user code execution during load time and use [add_safe_globals](https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals) to allowlist some known user functions. + +The reason why it does not support safe tensor serialization is that wrapper tensor subclass allows maximum flexibility so we want to make sure the effort of supporting new format of quantized Tensor is low, while safe tensor optimizes for maximum safety (no user code execution), it also means we have to make sure to manually support new quantization format. + +```py +# save quantized model locally +output_dir = "llama3-8b-int4wo-128" +quantized_model.save_pretrained(output_dir, safe_serialization=False) + +# push to huggingface hub +# save_to = "{user_id}/llama3-8b-int4wo-128" +# quantized_model.push_to_hub(save_to, safe_serialization=False) + +# load quantized model +ckpt_id = "llama3-8b-int4wo-128" # or huggingface hub model id +loaded_quantized_model = AutoModelForCausalLM.from_pretrained(ckpt_id, device_map="cuda") + + +# confirm the speedup +loaded_quantized_model = torch.compile(loaded_quantized_model, mode="max-autotune") +print("loaded int4wo-128 model:", benchmark_fn(loaded_quantized_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS)) +``` diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index d0f4239c38..40b6ff5d18 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -540,7 +540,11 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True): return torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys) -def load_state_dict(checkpoint_file: Union[str, os.PathLike], is_quantized: bool = False): +def load_state_dict( + checkpoint_file: Union[str, os.PathLike], + is_quantized: bool = False, + map_location: Optional[Union[str, torch.device]] = None, +): """ Reads a PyTorch checkpoint file, returning properly formatted errors if they arise. """ @@ -555,13 +559,18 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], is_quantized: bool ) return safe_load_file(checkpoint_file) try: - if ( - (is_deepspeed_zero3_enabled() and torch.distributed.is_initialized() and torch.distributed.get_rank() > 0) - or (is_fsdp_enabled() and not is_local_dist_rank_0()) - ) and not is_quantized: - map_location = "meta" - else: - map_location = "cpu" + if map_location is None: + if ( + ( + is_deepspeed_zero3_enabled() + and torch.distributed.is_initialized() + and torch.distributed.get_rank() > 0 + ) + or (is_fsdp_enabled() and not is_local_dist_rank_0()) + ) and not is_quantized: + map_location = "meta" + else: + map_location = "cpu" extra_args = {} # mmap can only be used with files serialized with zipfile-based format. if ( @@ -2564,7 +2573,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix hf_quantizer = getattr(self, "hf_quantizer", None) quantization_serializable = ( - hf_quantizer is not None and isinstance(hf_quantizer, HfQuantizer) and hf_quantizer.is_serializable + hf_quantizer is not None + and isinstance(hf_quantizer, HfQuantizer) + and hf_quantizer.is_serializable(safe_serialization=safe_serialization) ) if hf_quantizer is not None and not _hf_peft_config_loaded and not quantization_serializable: @@ -4479,7 +4490,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # Skip the load for shards that only contain disk-offloaded weights when using safetensors for the offload. if shard_file in disk_only_shard_files: continue - state_dict = load_state_dict(shard_file, is_quantized=is_quantized) + map_location = None + if ( + device_map is not None + and hf_quantizer is not None + and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO + and hf_quantizer.quantization_config.quant_type == "int4_weight_only" + ): + map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0]) + state_dict = load_state_dict(shard_file, is_quantized=is_quantized, map_location=map_location) # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not # matching the weights in the model. diff --git a/src/transformers/quantizers/base.py b/src/transformers/quantizers/base.py index 3ee28ada1b..73b3dbd8b2 100644 --- a/src/transformers/quantizers/base.py +++ b/src/transformers/quantizers/base.py @@ -217,9 +217,8 @@ class HfQuantizer(ABC): @abstractmethod def _process_model_after_weight_loading(self, model, **kwargs): ... - @property @abstractmethod - def is_serializable(self): ... + def is_serializable(self, safe_serialization=None): ... @property @abstractmethod diff --git a/src/transformers/quantizers/quantizer_aqlm.py b/src/transformers/quantizers/quantizer_aqlm.py index 5300716161..9d1d6f7e89 100644 --- a/src/transformers/quantizers/quantizer_aqlm.py +++ b/src/transformers/quantizers/quantizer_aqlm.py @@ -93,6 +93,5 @@ class AqlmHfQuantizer(HfQuantizer): ) return False - @property - def is_serializable(self): + def is_serializable(self, safe_serialization=None): return True diff --git a/src/transformers/quantizers/quantizer_awq.py b/src/transformers/quantizers/quantizer_awq.py index f9e4444f07..9c66ba385a 100644 --- a/src/transformers/quantizers/quantizer_awq.py +++ b/src/transformers/quantizers/quantizer_awq.py @@ -106,8 +106,7 @@ class AwqQuantizer(HfQuantizer): model = post_init_awq_exllama_modules(model, self.quantization_config.exllama_config) - @property - def is_serializable(self): + def is_serializable(self, safe_serialization=None): # AWQ through auto-awq has been always serializable, except if the model is fused. if self.quantization_config.do_fuse: logger.warning("You cannot save an AWQ model that uses fused modules!") diff --git a/src/transformers/quantizers/quantizer_bnb_4bit.py b/src/transformers/quantizers/quantizer_bnb_4bit.py index 73e7664aeb..eed45192e7 100644 --- a/src/transformers/quantizers/quantizer_bnb_4bit.py +++ b/src/transformers/quantizers/quantizer_bnb_4bit.py @@ -320,11 +320,10 @@ class Bnb4BitHfQuantizer(HfQuantizer): # Copied from transformers.quantizers.quantizer_bnb_8bit.Bnb8BitHfQuantizer._process_model_after_weight_loading with 8bit->4bit def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs): model.is_loaded_in_4bit = True - model.is_4bit_serializable = self.is_serializable + model.is_4bit_serializable = self.is_serializable() return model - @property - def is_serializable(self): + def is_serializable(self, safe_serialization=None): _is_4bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.41.3") if not _is_4bit_serializable: diff --git a/src/transformers/quantizers/quantizer_bnb_8bit.py b/src/transformers/quantizers/quantizer_bnb_8bit.py index 65d97716d0..020ff7cc62 100644 --- a/src/transformers/quantizers/quantizer_bnb_8bit.py +++ b/src/transformers/quantizers/quantizer_bnb_8bit.py @@ -210,7 +210,7 @@ class Bnb8BitHfQuantizer(HfQuantizer): raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {target_device}.") new_value = param_value.to("cpu") - if self.pre_quantized and not self.is_serializable: + if self.pre_quantized and not self.is_serializable(): raise ValueError( "Detected int8 weights but the version of bitsandbytes is not compatible with int8 serialization. " "Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`." @@ -238,7 +238,7 @@ class Bnb8BitHfQuantizer(HfQuantizer): def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs): model.is_loaded_in_8bit = True - model.is_8bit_serializable = self.is_serializable + model.is_8bit_serializable = self.is_serializable() return model def _process_model_before_weight_loading( @@ -282,8 +282,7 @@ class Bnb8BitHfQuantizer(HfQuantizer): model.config.quantization_config = self.quantization_config - @property - def is_serializable(self): + def is_serializable(self, safe_serialization=None): _bnb_supports_8bit_serialization = version.parse(importlib.metadata.version("bitsandbytes")) > version.parse( "0.37.2" ) diff --git a/src/transformers/quantizers/quantizer_compressed_tensors.py b/src/transformers/quantizers/quantizer_compressed_tensors.py index 5531838e56..347be5c665 100644 --- a/src/transformers/quantizers/quantizer_compressed_tensors.py +++ b/src/transformers/quantizers/quantizer_compressed_tensors.py @@ -72,6 +72,5 @@ class CompressedTensorsHfQuantizer(HfQuantizer): def is_trainable(self): return False - @property - def is_serializable(self): + def is_serializable(self, safe_serialization=None): return False diff --git a/src/transformers/quantizers/quantizer_eetq.py b/src/transformers/quantizers/quantizer_eetq.py index 7be0a6bd9e..602df62c01 100644 --- a/src/transformers/quantizers/quantizer_eetq.py +++ b/src/transformers/quantizers/quantizer_eetq.py @@ -161,8 +161,7 @@ class EetqHfQuantizer(HfQuantizer): model.config.quantization_config = self.quantization_config - @property - def is_serializable(self): + def is_serializable(self, safe_serialization=None): return True @property diff --git a/src/transformers/quantizers/quantizer_fbgemm_fp8.py b/src/transformers/quantizers/quantizer_fbgemm_fp8.py index 6591a56fce..07d5ce87ef 100644 --- a/src/transformers/quantizers/quantizer_fbgemm_fp8.py +++ b/src/transformers/quantizers/quantizer_fbgemm_fp8.py @@ -196,8 +196,7 @@ class FbgemmFp8HfQuantizer(HfQuantizer): not_missing_keys.append(missing) return [k for k in missing_keys if k not in not_missing_keys] - @property - def is_serializable(self): + def is_serializable(self, safe_serialization=None): return True @property diff --git a/src/transformers/quantizers/quantizer_gptq.py b/src/transformers/quantizers/quantizer_gptq.py index ffc6f2090a..233a5279d3 100644 --- a/src/transformers/quantizers/quantizer_gptq.py +++ b/src/transformers/quantizers/quantizer_gptq.py @@ -89,6 +89,5 @@ class GptqHfQuantizer(HfQuantizer): def is_trainable(self, model: Optional["PreTrainedModel"] = None): return True - @property - def is_serializable(self): + def is_serializable(self, safe_serialization=None): return True diff --git a/src/transformers/quantizers/quantizer_hqq.py b/src/transformers/quantizers/quantizer_hqq.py index 14be75369d..cd32a99c00 100755 --- a/src/transformers/quantizers/quantizer_hqq.py +++ b/src/transformers/quantizers/quantizer_hqq.py @@ -188,11 +188,10 @@ class HqqHfQuantizer(HfQuantizer): def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs): model.is_hqq_quantized = True - model.is_hqq_serializable = self.is_serializable + model.is_hqq_serializable = self.is_serializable() return model - @property - def is_serializable(self): + def is_serializable(self, safe_serialization=None): return False @property diff --git a/src/transformers/quantizers/quantizer_quanto.py b/src/transformers/quantizers/quantizer_quanto.py index e7e2219ab6..ae113f714a 100644 --- a/src/transformers/quantizers/quantizer_quanto.py +++ b/src/transformers/quantizers/quantizer_quanto.py @@ -195,6 +195,5 @@ class QuantoHfQuantizer(HfQuantizer): def is_trainable(self, model: Optional["PreTrainedModel"] = None): return False - @property - def is_serializable(self): + def is_serializable(self, safe_serialization=None): return False diff --git a/src/transformers/quantizers/quantizer_torchao.py b/src/transformers/quantizers/quantizer_torchao.py index 02ea8294a2..f6bf431aa0 100644 --- a/src/transformers/quantizers/quantizer_torchao.py +++ b/src/transformers/quantizers/quantizer_torchao.py @@ -160,9 +160,19 @@ class TorchAoHfQuantizer(HfQuantizer): """No process required for torchao quantized model""" return - @property - def is_serializable(self): - return False + def is_serializable(self, safe_serialization=None): + if safe_serialization: + logger.warning( + "torchao quantized model does not support safe serialization, " + "please set `safe_serialization` to False" + ) + return False + _is_torchao_serializable = version.parse(importlib.metadata.version("huggingface_hub")) >= version.parse( + "0.25.0" + ) + if not _is_torchao_serializable: + logger.warning("torchao quantized model is only serializable after huggingface_hub >= 0.25.0 ") + return _is_torchao_serializable @property def is_trainable(self): diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index 23a983af74..19166f9ed9 100755 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -95,7 +95,6 @@ class QuantizationConfigMixin: Returns: [`QuantizationConfigMixin`]: The configuration object instantiated from those parameters. """ - config = cls(**config_dict) to_remove = [] @@ -1235,24 +1234,11 @@ class TorchAoConfig(QuantizationConfigMixin): self.quant_method = QuantizationMethod.TORCHAO self.quant_type = quant_type self.modules_to_not_convert = modules_to_not_convert - self.kwargs = kwargs - self._STR_TO_METHOD = {} - if is_torchao_available(): - from torchao.quantization import ( - int4_weight_only, - int8_dynamic_activation_int8_weight, - int8_weight_only, - ) - - self._STR_TO_METHOD = { - "int4_weight_only": int4_weight_only, - "int8_weight_only": int8_weight_only, - "int8_dynamic_activation_int8_weight": int8_dynamic_activation_int8_weight, - } + # when we load from serailized config, "quant_type_kwargs" will be the key + if "quant_type_kwargs" in kwargs: + self.quant_type_kwargs = kwargs["quant_type_kwargs"] else: - raise ValueError( - "TorchAoConfig requires torchao to be installed, please install with `pip install torchao`" - ) + self.quant_type_kwargs = kwargs self.post_init() @@ -1263,26 +1249,46 @@ class TorchAoConfig(QuantizationConfigMixin): if not version.parse(importlib.metadata.version("torchao")) >= version.parse("0.4.0"): raise ValueError("Requires torchao 0.4.0 version and above") - if self.quant_type not in self._STR_TO_METHOD.keys(): + _STR_TO_METHOD = self._get_torchao_quant_type_to_method() + if self.quant_type not in _STR_TO_METHOD.keys(): raise ValueError( f"Requested quantization type: {self.quant_type} is not supported yet, please add support in TorchAoConfig and TorchAoHfQuantizer." ) - method = self._STR_TO_METHOD[self.quant_type] + method = _STR_TO_METHOD[self.quant_type] sig = signature(method) all_kwargs = [ param.name for param in sig.parameters.values() if param.kind in [Parameter.KEYWORD_ONLY, Parameter.POSITIONAL_OR_KEYWORD] ] - for k in self.kwargs: + for k in self.quant_type_kwargs: if k not in all_kwargs: raise ValueError( f"Unexpected keyword arg: {k} for API: {method}, accepted keyword args are: {all_kwargs}" ) + def _get_torchao_quant_type_to_method(self): + if is_torchao_available(): + from torchao.quantization import ( + int4_weight_only, + int8_dynamic_activation_int8_weight, + int8_weight_only, + ) + + return { + "int4_weight_only": int4_weight_only, + "int8_weight_only": int8_weight_only, + "int8_dynamic_activation_int8_weight": int8_dynamic_activation_int8_weight, + } + else: + raise ValueError( + "TorchAoConfig requires torchao to be installed, please install with `pip install torchao`" + ) + def get_apply_tensor_subclass(self): - return self._STR_TO_METHOD[self.quant_type](**self.kwargs) + _STR_TO_METHOD = self._get_torchao_quant_type_to_method() + return _STR_TO_METHOD[self.quant_type](**self.quant_type_kwargs) def __repr__(self): - return f"{self.quant_type}({', '.join(str(k) + '=' + str(v) for k, v in self.kwargs.items())})" + return f"{self.quant_type}({', '.join(str(k) + '=' + str(v) for k, v in self.quant_type_kwargs.items())})"