diff --git a/docs/source/en/quantization/torchao.md b/docs/source/en/quantization/torchao.md index 46fb0f8cbb..06017c3f3e 100644 --- a/docs/source/en/quantization/torchao.md +++ b/docs/source/en/quantization/torchao.md @@ -22,6 +22,12 @@ pip install --upgrade torch torchao transformers By default, the weights are loaded in full precision (torch.float32) regardless of the actual data type the weights are stored in such as torch.float16. Set `torch_dtype="auto"` to load the weights in the data type defined in a model's `config.json` file to automatically load the most memory-optimal data type. +## Manually Choose Quantization Types and Settings + +`torchao` Provides many commonly used types of quantization, including different dtypes like int4, float8 and different flavors like weight only, dynamic quantization etc., only `int4_weight_only`, `int8_weight_only` and `int8_dynamic_activation_int8_weight` are integrated into hugigngface transformers currently, but we can add more when needed. + +Users can manually specify the quantization types and settings they want to use: + ```py import torch from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer @@ -41,19 +47,14 @@ output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implemen print(tokenizer.decode(output[0], skip_special_tokens=True)) # benchmark the performance -import torch.utils.benchmark as benchmark +from torch._inductor.utils import do_bench_using_profiling +from typing import Callable -def benchmark_fn(f, *args, **kwargs): - # Manual warmup - for _ in range(5): - f(*args, **kwargs) - - t0 = benchmark.Timer( - stmt="f(*args, **kwargs)", - globals={"args": args, "kwargs": kwargs, "f": f}, - num_threads=torch.get_num_threads(), - ) - return f"{(t0.blocked_autorange().mean):.3f}" +def benchmark_fn(func: Callable, *args, **kwargs) -> float: + """Thin wrapper around do_bench_using_profiling""" + no_args = lambda: func(*args, **kwargs) + time = do_bench_using_profiling(no_args) + return time * 1e3 MAX_NEW_TOKENS = 1000 print("int4wo-128 model:", benchmark_fn(quantized_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation="static")) @@ -64,6 +65,47 @@ print("bf16 model:", benchmark_fn(bf16_model.generate, **input_ids, max_new_toke ``` +## Automatically Select Quantization Types + +`torchao` also provies `autoquant` feature that automatically chooses a quantization type for quantizable layers such as linear based on microbenchmarks of quantizing and compiling a single linear layer. + +```py +import torch +from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer + +model_name = "meta-llama/Meta-Llama-3-8B" +quantization_config = TorchAoConfig("autoquant", min_sqnr=None) +quantized_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto", quantization_config=quantization_config) + +tokenizer = AutoTokenizer.from_pretrained(model_name) +input_text = "What are we having for dinner?" +input_ids = tokenizer(input_text, return_tensors="pt").to("cuda") + +# auto-compile the quantized model with `cache_implementation="static"` to get speedup +output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static") +# Due to some implementation details we are explicitly calling this now, we may refactor our code and remove this in the future +quantized_model.finalize_autoquant() +print(tokenizer.decode(output[0], skip_special_tokens=True)) + +# benchmark the performance +from torch._inductor.utils import do_bench_using_profiling +from typing import Callable + +def benchmark_fn(func: Callable, *args, **kwargs) -> float: + """Thin wrapper around do_bench_using_profiling""" + no_args = lambda: func(*args, **kwargs) + time = do_bench_using_profiling(no_args) + return time * 1e3 + +MAX_NEW_TOKENS = 1000 +print("autoquantized model:", benchmark_fn(quantized_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation="static")) + +bf16_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cuda", torch_dtype=torch.bfloat16) +output = bf16_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static") # auto-compile +print("bf16 model:", benchmark_fn(bf16_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation="static")) + +``` + ## Serialization and Deserialization torchao quantization is implemented with [tensor subclasses](https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor), it only work with huggingface non-safetensor serialization and deserialization. It relies on `torch.load(..., weights_only=True)` to avoid arbitrary user code execution during load time and use [add_safe_globals](https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals) to allowlist some known user functions. diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index e3b2209f02..0d03de2add 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4914,7 +4914,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix device_map is not None and hf_quantizer is not None and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO - and hf_quantizer.quantization_config.quant_type == "int4_weight_only" + and hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"] ): map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0]) state_dict = load_state_dict( diff --git a/src/transformers/quantizers/quantizer_torchao.py b/src/transformers/quantizers/quantizer_torchao.py index bcc9c57dfa..8439e68a90 100644 --- a/src/transformers/quantizers/quantizer_torchao.py +++ b/src/transformers/quantizers/quantizer_torchao.py @@ -129,6 +129,7 @@ class TorchAoHfQuantizer(HfQuantizer): "int4_weight_only": CustomDtype.INT4, "int8_weight_only": torch.int8, "int8_dynamic_activation_int8_weight": torch.int8, + "autoquant": None, } return map_to_target_dtype[self.quantization_config.quant_type] else: @@ -161,6 +162,9 @@ class TorchAoHfQuantizer(HfQuantizer): state_dict: Dict[str, Any], **kwargs, ) -> bool: + if self.quantization_config.quant_type == "autoquant": + return False + param_device = kwargs.pop("param_device", None) # check if the param_name is not in self.modules_to_not_convert if any((key + "." in param_name) or (key == param_name) for key in self.modules_to_not_convert): @@ -186,6 +190,9 @@ class TorchAoHfQuantizer(HfQuantizer): Each nn.Linear layer that needs to be quantized is processsed here. First, we set the value the weight tensor, then we move it to the target device. Finally, we quantize the module. """ + if self.quantization_config.quant_type == "autoquant": + return + from torchao.quantization import quantize_ module, tensor_name = get_module_from_name(model, param_name) @@ -200,6 +207,15 @@ class TorchAoHfQuantizer(HfQuantizer): def _process_model_after_weight_loading(self, model, **kwargs): """No process required for torchao quantized model""" + if self.quantization_config.quant_type == "autoquant": + from torchao import autoquant + from torchao.quantization import ALL_AUTOQUANT_CLASS_LIST + + model = torch.compile(model, mode="max-autotune") + model = autoquant( + model, qtensor_class_list=ALL_AUTOQUANT_CLASS_LIST, **self.quantization_config.quant_type_kwargs + ) + return model return def is_serializable(self, safe_serialization=None): diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index 3fafca29b9..2ac53dc315 100755 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -1453,7 +1453,7 @@ class TorchAoConfig(QuantizationConfigMixin): Args: quant_type (`str`): - The type of quantization we want to use, currently supporting: `int4_weight_only`, `int8_weight_only` and `int8_dynamic_activation_int8_weight`. + The type of quantization we want to use, currently supporting: `int4_weight_only`, `int8_weight_only`, `int8_dynamic_activation_int8_weight` and `autoquant`. modules_to_not_convert (`list`, *optional*, default to `None`): The list of modules to not quantize, useful for quantizing models that explicitly require to have some modules left in their original precision. @@ -1465,9 +1465,31 @@ class TorchAoConfig(QuantizationConfigMixin): Example: ```python + from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig + + # specific quantization method quantization_config = TorchAoConfig("int4_weight_only", group_size=32) # int4_weight_only quant is only working with *torch.bfloat16* dtype right now model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda", torch_dtype=torch.bfloat16, quantization_config=quantization_config) + + # autoquant + # `autoquant` is a convenient way for users to search for the best quantization for each layer + # `min_sqnr` is an option to control the accuracy of the model, higher value means the model is more + # accurate, we can start with 30 and adjust it to larger or smaller (e.g. 40, 20) + # defaults to None, which means we'll try to get the best performing quantized model without + # considering accuracy + quantization_config = TorchAoConfig("autoquant", min_sqnr=30) + model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda", torch_dtype=torch.bfloat16, quantization_config=quantization_config) + # run through example inputs, quantization methods will be selected based on the shape of example input + tokenizer = AutoTokenizer.from_pretrained(model_name) + input_text = "What are we having for dinner?" + input_ids = tokenizer(input_text, return_tensors="pt").to("cuda") + MAX_NEW_TOKENS = 1000 + model.generate(**input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation="static") + # manually ran finalize_autoquant if needed + if hasattr(quantized_model, "finalize_autoquant"): + print("finalizing autoquant") + quantized_model.finalize_autoquant() ``` """ @@ -1488,8 +1510,8 @@ class TorchAoConfig(QuantizationConfigMixin): Safety checker that arguments are correct - also replaces some NoneType arguments with their default values. """ if is_torchao_available(): - if not version.parse(importlib.metadata.version("torchao")) >= version.parse("0.4.0"): - raise ValueError("Requires torchao 0.4.0 version and above") + if not version.parse(importlib.metadata.version("torchao")) >= version.parse("0.7.0"): + raise ValueError("Requires torchao 0.7.0 version and above") else: raise ValueError( "TorchAoConfig requires torchao to be installed, please install with `pip install torchao`" @@ -1517,6 +1539,7 @@ class TorchAoConfig(QuantizationConfigMixin): def _get_torchao_quant_type_to_method(self): if is_torchao_available(): from torchao.quantization import ( + autoquant, int4_weight_only, int8_dynamic_activation_int8_weight, int8_weight_only, @@ -1526,6 +1549,7 @@ class TorchAoConfig(QuantizationConfigMixin): "int4_weight_only": int4_weight_only, "int8_weight_only": int8_weight_only, "int8_dynamic_activation_int8_weight": int8_dynamic_activation_int8_weight, + "autoquant": autoquant, } else: raise ValueError( diff --git a/tests/quantization/torchao_integration/test_torchao.py b/tests/quantization/torchao_integration/test_torchao.py index 1708550cf0..60694924cd 100644 --- a/tests/quantization/torchao_integration/test_torchao.py +++ b/tests/quantization/torchao_integration/test_torchao.py @@ -31,10 +31,12 @@ if is_torch_available(): import torch if is_torchao_available(): + # renamed in torchao 0.7.0, please install the latest torchao from torchao.dtypes import ( AffineQuantizedTensor, TensorCoreTiledLayout, ) + from torchao.quantization.autoquant import AQMixin def check_torchao_quantized(test_module, qlayer, batch_size=1, context_size=1024): @@ -42,7 +44,12 @@ def check_torchao_quantized(test_module, qlayer, batch_size=1, context_size=1024 test_module.assertTrue(isinstance(weight, AffineQuantizedTensor)) test_module.assertEqual(weight.quant_min, 0) test_module.assertEqual(weight.quant_max, 15) - test_module.assertTrue(isinstance(weight.layout, TensorCoreTiledLayout)) + test_module.assertTrue(isinstance(weight._layout, TensorCoreTiledLayout)) + + +def check_autoquantized(test_module, qlayer): + weight = qlayer.weight + test_module.assertTrue(isinstance(weight, AQMixin)) def check_forward(test_module, model, batch_size=1, context_size=1024): @@ -248,6 +255,33 @@ class TorchAoTest(unittest.TestCase): EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)" self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_OUTPUT) + def test_autoquant(self): + """ + Simple LLM model testing autoquant + """ + quant_config = TorchAoConfig("autoquant") + + quantized_model = AutoModelForCausalLM.from_pretrained( + self.model_name, + torch_dtype=torch.bfloat16, + device_map=torch_device, + quantization_config=quant_config, + ) + tokenizer = AutoTokenizer.from_pretrained(self.model_name) + input_ids = tokenizer(self.input_text, return_tensors="pt").to(torch_device) + output = quantized_model.generate( + **input_ids, max_new_tokens=self.max_new_tokens, cache_implementation="static" + ) + quantized_model.finalize_autoquant() + + check_autoquantized(self, quantized_model.model.layers[0].self_attn.v_proj) + + EXPECTED_OUTPUT = 'What are we having for dinner?\n\n10. "Dinner is ready' + output = quantized_model.generate( + **input_ids, max_new_tokens=self.max_new_tokens, cache_implementation="static" + ) + self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_OUTPUT) + @require_torch_gpu @require_torchao