From 86777b5e2f651d7f7c46db919beb13893743a5b5 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Wed, 30 Apr 2025 11:16:29 -0700 Subject: [PATCH] Support `AOPerModuleConfig` and `include_embedding` (#37802) * Support `AOPerModuleConfig` and include_embedding Summary: This PR adds support per module configuration for torchao Also added per module quantization examples: 1. Quantizing different layers with different quantization configs 2. Skip quantization for certain layers Test Plan: python tests/quantization/torchao_integration/test_torchao.py -k test_include_embedding python tests/quantization/torchao_integration/test_torchao.py -k test_per_module_config_skip Reviewers: Subscribers: Tasks: Tags: * format * format * inlcude embedding remove input embedding from module not to convert * more docs * Update docs/source/en/quantization/torchao.md Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com> * Update src/transformers/quantizers/quantizer_torchao.py Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com> * Update src/transformers/quantizers/quantizer_torchao.py Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com> --------- Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com> --- docs/source/en/quantization/torchao.md | 78 ++++++++++++++++++- .../quantizers/quantizer_torchao.py | 28 ++++++- src/transformers/utils/quantization_config.py | 6 ++ .../torchao_integration/test_torchao.py | 61 +++++++++++++++ 4 files changed, 167 insertions(+), 6 deletions(-) diff --git a/docs/source/en/quantization/torchao.md b/docs/source/en/quantization/torchao.md index 42fed458f7..bee2e008b9 100644 --- a/docs/source/en/quantization/torchao.md +++ b/docs/source/en/quantization/torchao.md @@ -40,6 +40,8 @@ torchao supports the [quantization techniques](https://github.com/pytorch/ao/blo - A16W4 Int4 Weight Only Quantization - Autoquantization +torchao also supports module level 
configuration by specifying a dictionary that maps the fully qualified name of each module to its quantization config. This allows skipping quantization for certain layers and using different quantization configs for different modules. + Check the table below to see if your hardware is compatible. @@ -89,7 +91,7 @@ We'll show examples for recommended quantization methods based on hardwares, e.g ```py import torch from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer -from torchao.quantization import Float8DynamicActivationFloat8WeightConfig +from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, Float8WeightOnlyConfig quant_config = Float8DynamicActivationFloat8WeightConfig() # or float8 weight only quantization @@ -149,7 +151,7 @@ print(tokenizer.decode(output[0], skip_special_tokens=True)) ```py import torch from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer -from torchao.quantization import Int8DynamicActivationInt8WeightConfig +from torchao.quantization import Int8DynamicActivationInt8WeightConfig, Int8WeightOnlyConfig quant_config = Int8DynamicActivationInt8WeightConfig() # or int8 weight only quantization @@ -179,7 +181,7 @@ print(tokenizer.decode(output[0], skip_special_tokens=True)) ```py import torch from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer -from torchao.quantization import GemliteUIntXWeightOnlyConfig +from torchao.quantization import GemliteUIntXWeightOnlyConfig, Int4WeightOnlyConfig # For batch size N, we recommend gemlite, which may require autotuning # default is 4 bit, 8 bit is also supported by passing `bit_width=8` @@ -216,7 +218,7 @@ print(tokenizer.decode(output[0], skip_special_tokens=True)) ```py import torch from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer -from torchao.quantization import Int8DynamicActivationInt8WeightConfig +from torchao.quantization import Int8DynamicActivationInt8WeightConfig, Int8WeightOnlyConfig 
quant_config = Int8DynamicActivationInt8WeightConfig() # quant_config = Int8WeightOnlyConfig() @@ -272,6 +274,74 @@ print(tokenizer.decode(output[0], skip_special_tokens=True)) +### Per Module Quantization +#### 1. Skip quantization for certain layers +With `AOPerModuleConfig` we can specify a default configuration for all layers while skipping quantization for certain layers. +```py +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig + +model_id = "meta-llama/Llama-3.1-8B-Instruct" + +from torchao.quantization import Int4WeightOnlyConfig, AOPerModuleConfig +config = Int4WeightOnlyConfig(group_size=128) + +# set default to int4 (for linears), and skip quantizing `model.layers.0.self_attn.q_proj` +quant_config = AOPerModuleConfig({"_default": config, "model.layers.0.self_attn.q_proj": None}) +quantization_config = TorchAoConfig(quant_type=quant_config) +quantized_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16, quantization_config=quantization_config) +# lm_head is not quantized and model.layers.0.self_attn.q_proj is not quantized +print("quantized model:", quantized_model) +tokenizer = AutoTokenizer.from_pretrained(model_id) + +# Manual Testing +prompt = "Hey, are you conscious? Can you talk to me?" +inputs = tokenizer(prompt, return_tensors="pt").to("cuda") +generated_ids = quantized_model.generate(**inputs, max_new_tokens=128) +output_text = tokenizer.batch_decode( + generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False +) +print(output_text) +``` + +#### 2. 
Quantizing different layers with different quantization configs +```py +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig + +model_id = "facebook/opt-125m" + +from torchao.quantization import Int4WeightOnlyConfig, AOPerModuleConfig, Int8DynamicActivationInt4WeightConfig, IntxWeightOnlyConfig, PerAxis, MappingType + +weight_dtype = torch.int8 +granularity = PerAxis(0) +mapping_type = MappingType.ASYMMETRIC +embedding_config = IntxWeightOnlyConfig( + weight_dtype=weight_dtype, + granularity=granularity, + mapping_type=mapping_type, +) +linear_config = Int8DynamicActivationInt4WeightConfig(group_size=128) +quant_config = AOPerModuleConfig({"_default": linear_config, "model.decoder.embed_tokens": embedding_config, "model.decoder.embed_positions": None}) +# set `include_embedding` to True in order to include embedding in quantization +# when `include_embedding` is True, we'll remove input embedding from `modules_not_to_convert` as well +quantization_config = TorchAoConfig(quant_type=quant_config, include_embedding=True) +quantized_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cpu", torch_dtype=torch.bfloat16, quantization_config=quantization_config) +print("quantized model:", quantized_model) +# make sure embedding is quantized +print("embed_tokens weight:", quantized_model.model.decoder.embed_tokens.weight) +tokenizer = AutoTokenizer.from_pretrained(model_id) + +# Manual Testing +prompt = "Hey, are you conscious? Can you talk to me?" 
+inputs = tokenizer(prompt, return_tensors="pt").to("cpu") +generated_ids = quantized_model.generate(**inputs, max_new_tokens=128, cache_implementation="static") +output_text = tokenizer.batch_decode( + generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False +) +print(output_text) +``` + ### Autoquant If you want to automatically choose a quantization type for quantizable layers (`nn.Linear`) you can use the [autoquant](https://pytorch.org/ao/stable/generated/torchao.quantization.autoquant.html#torchao.quantization.autoquant) API. diff --git a/src/transformers/quantizers/quantizer_torchao.py b/src/transformers/quantizers/quantizer_torchao.py index 5b23a6173d..31d764c302 100644 --- a/src/transformers/quantizers/quantizer_torchao.py +++ b/src/transformers/quantizers/quantizer_torchao.py @@ -185,6 +185,10 @@ class TorchAoHfQuantizer(HfQuantizer): self.modules_to_not_convert = self.get_modules_to_not_convert( model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules ) + if self.quantization_config.include_embedding: + input_emb = model.get_input_embeddings() + input_emb_names = [name for name, module in model.named_modules() if id(module) == id(input_emb)] + self.modules_to_not_convert = [x for x in self.modules_to_not_convert if x not in input_emb_names] return def check_quantized_param( @@ -206,9 +210,12 @@ class TorchAoHfQuantizer(HfQuantizer): # We don't quantize weights that we offload return False else: - # we only quantize the weight of nn.Linear + # we only quantize the weight of nn.Linear and nn.Embedding module, tensor_name = get_module_from_name(model, param_name) - return isinstance(module, torch.nn.Linear) and (tensor_name == "weight") + _QUANTIZABLE = [torch.nn.Linear] + if self.quantization_config.include_embedding: + _QUANTIZABLE.append(torch.nn.Embedding) + return isinstance(module, tuple(_QUANTIZABLE)) and (tensor_name == "weight") def create_quantized_param( self, @@ -240,6 +247,23 @@ class 
TorchAoHfQuantizer(HfQuantizer): module._parameters[tensor_name] = torch.nn.Parameter( param_value, requires_grad=param_value.requires_grad ).to(device=target_device) + # handle AOPerModuleConfig, introduced in torchao 0.11.0+ + if self.quantization_config._get_ao_version() > version.Version("0.10.0"): + from torchao.quantization import AOPerModuleConfig + + config = self.quantization_config.get_apply_tensor_subclass() + if isinstance(config, AOPerModuleConfig): + module_fqn, _ = param_name.rsplit(".", 1) + c = None + if module_fqn in config.module_fqn_to_config: + c = config.module_fqn_to_config[module_fqn] + else: + c = config.module_fqn_to_config.get("_default", None) + if c is not None: + # filter_fn: not filtering out any modules + quantize_(module, c, filter_fn=lambda x, fqn: True) + return + quantize_(module, self.quantization_config.get_apply_tensor_subclass()) def _process_model_after_weight_loading(self, model, **kwargs): diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index 72b3837142..aa1f714c46 100644 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -1554,6 +1554,7 @@ class TorchAoConfig(QuantizationConfigMixin): quant_type: Union[str, "AOBaseConfig"] # noqa: F821 modules_to_not_convert: Optional[List] quant_type_kwargs: Dict[str, Any] + include_embedding: bool """This is a config class for torchao quantization/sparsity techniques. @@ -1565,6 +1566,9 @@ class TorchAoConfig(QuantizationConfigMixin): modules_to_not_convert (`list`, *optional*, default to `None`): The list of modules to not quantize, useful for quantizing models that explicitly require to have some modules left in their original precision. + include_embedding (`bool`, default to `False`): + Whether to include embedding in quantization or not, input embedding will be removed from + the modules_to_not_convert list as well if this flag is set. 
kwargs (`Dict[str, Any]`, *optional*): The keyword arguments for the chosen type of quantization, for example, int4_weight_only quantization supports two keyword arguments `group_size` and `inner_k_tiles` currently. More API examples and documentation of arguments can be found in @@ -1609,12 +1613,14 @@ class TorchAoConfig(QuantizationConfigMixin): self, quant_type: Union[str, "AOBaseConfig"], # noqa: F821 modules_to_not_convert: Optional[List] = None, + include_embedding: bool = False, **kwargs, ): self.quant_method = QuantizationMethod.TORCHAO self.quant_type = quant_type self.modules_to_not_convert = modules_to_not_convert self.quant_type_kwargs = kwargs.get("quant_type_kwargs", kwargs) + self.include_embedding = include_embedding self.post_init() @staticmethod diff --git a/tests/quantization/torchao_integration/test_torchao.py b/tests/quantization/torchao_integration/test_torchao.py index 3f24d9a95a..61d569a040 100644 --- a/tests/quantization/torchao_integration/test_torchao.py +++ b/tests/quantization/torchao_integration/test_torchao.py @@ -38,6 +38,13 @@ if is_torchao_available(): AffineQuantizedTensor, TensorCoreTiledLayout, ) + from torchao.quantization import ( + AOPerModuleConfig, + Int8WeightOnlyConfig, + IntxWeightOnlyConfig, + MappingType, + PerAxis, + ) from torchao.quantization.autoquant import AQMixin if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0"): @@ -193,6 +200,60 @@ class TorchAoTest(unittest.TestCase): ] self.assertTrue(tokenizer.decode(output[0], skip_special_tokens=True) in EXPECTED_OUTPUT) + @require_torchao_version_greater_or_equal("0.11.0") + def test_include_embedding(self): + weight_dtype = torch.int8 + granularity = PerAxis(0) + mapping_type = MappingType.ASYMMETRIC + embedding_config = IntxWeightOnlyConfig( + weight_dtype=weight_dtype, + granularity=granularity, + mapping_type=mapping_type, + ) + config = AOPerModuleConfig({"_default": None, "model.embed_tokens": embedding_config}) + # need set 
`include_embedding` to True + quant_config = TorchAoConfig(quant_type=config, include_embedding=True) + quantized_model = AutoModelForCausalLM.from_pretrained( + self.model_name, + device_map=self.device, + quantization_config=quant_config, + ) + # making sure embedding is quantized + self.assertTrue(isinstance(quantized_model.model.embed_tokens.weight, AffineQuantizedTensor)) + tokenizer = AutoTokenizer.from_pretrained(self.model_name) + + input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device) + + output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens) + EXPECTED_OUTPUT = [ + "What are we having for dinner?\n\nJessica: (smiling)", + "What are we having for dinner?\n\nJess: (smiling) I", + ] + self.assertTrue(tokenizer.decode(output[0], skip_special_tokens=True) in EXPECTED_OUTPUT) + + @require_torchao_version_greater_or_equal("0.11.0") + def test_per_module_config_skip(self): + linear_config = Int8WeightOnlyConfig() + config = AOPerModuleConfig({"_default": linear_config, "model.layers.0.self_attn.q_proj": None}) + quant_config = TorchAoConfig(quant_type=config) + quantized_model = AutoModelForCausalLM.from_pretrained( + self.model_name, + device_map=self.device, + quantization_config=quant_config, + ) + # making sure `model.layers.0.self_attn.q_proj` is skipped + self.assertTrue(not isinstance(quantized_model.model.layers[0].self_attn.q_proj.weight, AffineQuantizedTensor)) + tokenizer = AutoTokenizer.from_pretrained(self.model_name) + + input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device) + + output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens) + EXPECTED_OUTPUT = [ + "What are we having for dinner?\n\nJessica: (smiling)", + "What are we having for dinner?\n\nJess: (smiling) I", + ] + self.assertTrue(tokenizer.decode(output[0], skip_special_tokens=True) in EXPECTED_OUTPUT) + @require_torch_gpu class TorchAoGPUTest(TorchAoTest):