diff --git a/docs/source/en/quantization/torchao.md b/docs/source/en/quantization/torchao.md index b2da1aba88..31e2d4f020 100644 --- a/docs/source/en/quantization/torchao.md +++ b/docs/source/en/quantization/torchao.md @@ -20,18 +20,95 @@ Install torchao with the following command. pip install --upgrade torch torchao transformers ``` -torchao supports many quantization types for different data types (int4, float8, weight only, etc.), but the Transformers integration only currently supports int8 weight quantization and int8 dynamic quantization of weights. +torchao supports many quantization types for different data types (int4, float8, weight only, etc.). +Starting with version 0.10.0, torchao provides enhanced flexibility through the `AOBaseConfig` API, allowing for more customized quantization configurations. +And full access to the techniques offered in the torchao library. You can manually choose the quantization types and settings or automatically select the quantization types. + Create a [`TorchAoConfig`] and specify the quantization type and `group_size` of the weights to quantize. Set the `cache_implementation` to `"static"` to automatically [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) the forward method. > [!TIP] > Run the quantized model on a CPU by changing `device_map` to `"cpu"` and `layout` to `Int4CPULayout()`. This is only available in torchao 0.8.0+. +In torchao 0.10.0+, you can use the more flexible `AOBaseConfig` approach instead of string identifiers: + +```py +import torch +from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer +from torchao.quantization import Int4WeightOnlyConfig + +# Using AOBaseConfig instance (torchao >= 0.10.0) +quant_config = Int4WeightOnlyConfig(group_size=128) +quantization_config = TorchAoConfig(quant_type=quant_config) + +# Load and quantize the model +quantized_model = AutoModelForCausalLM.from_pretrained( + "meta-llama/Meta-Llama-3-8B", + torch_dtype="auto", + device_map="auto", + quantization_config=quantization_config +) + +tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B") +input_text = "What are we having for dinner?" +input_ids = tokenizer(input_text, return_tensors="pt").to("cuda") + +# auto-compile the quantized model with `cache_implementation="static"` to get speed up +output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static") +print(tokenizer.decode(output[0], skip_special_tokens=True)) +``` + +## Available Quantization Schemes + +TorchAO provides a variety of quantization configurations: + +- `Int4WeightOnlyConfig` +- `Int8WeightOnlyConfig` +- `Int8DynamicActivationInt8WeightConfig` +- `Float8WeightOnlyConfig` + +Each configuration can be further customized with parameters such as `group_size`, `scheme`, and `layout` to optimize for specific hardware and model architectures. + +For a complete list of available configurations, see our [quantization API documentation](https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_api.py). + +> **⚠️ DEPRECATION WARNING** +> +> Starting with version 0.10.0, the string-based API for quantization configuration (e.g., `TorchAoConfig("int4_weight_only", group_size=128)`) is **deprecated** and will be removed in a future release. +> +> Please use the new `AOBaseConfig`-based approach instead: +> +> ```python +> # Old way (deprecated) +> quantization_config = TorchAoConfig("int4_weight_only", group_size=128) +> +> # New way (recommended) +> from torchao.quantization import Int4WeightOnlyConfig +> quant_config = Int4WeightOnlyConfig(group_size=128) +> quantization_config = TorchAoConfig(quant_type=quant_config) +> ``` +> +> The new API offers greater flexibility, better type safety, and access to the full range of features available in torchao. +> +> ## Migration Guide +> +> Here's how to migrate from common string identifiers to their `AOBaseConfig` equivalents: +> +> | Old String API | New `AOBaseConfig` API | +> |----------------|------------------------| +> | `"int4_weight_only"` | `Int4WeightOnlyConfig()` | +> | `"int8_weight_only"` | `Int8WeightOnlyConfig()` | +> | `"int8_dynamic_activation_int8_weight"` | `Int8DynamicActivationInt8WeightConfig()` | +> +> All configuration objects accept parameters for customization (e.g., `group_size`, `scheme`, `layout`). + + +Below is the API for for torchao < `0.9.0` + ```py import torch from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer @@ -78,7 +155,7 @@ print("bf16 model:", benchmark_fn(bf16_model.generate, **input_ids, max_new_toke The [autoquant](https://pytorch.org/ao/stable/generated/torchao.quantization.autoquant.html#torchao.quantization.autoquant) API automatically chooses a quantization type for quantizable layers (`nn.Linear`) by micro-benchmarking on input type and shape and compiling a single linear layer. -Create a [`TorchAoConfig`] and set to `"autoquant"`. Set the `cache_implementation` to `"static"` to automatically [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) the forward method. Finally, call `finalize_autoquant` on the quantized model to finalize the quantization and log the input shapes. +Create a [`TorchAoConfig`] and set to `"autoquant"`. Set the `cache_implementation` to `"static"` to automatically [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) the forward method. Finally, call `finalize_autoquant` on the quantized model to finalize the quantization and log the input shapes. > [!TIP] > Run the quantized model on a CPU by changing `device_map` to `"cpu"` and `layout` to `Int4CPULayout()`. This is only available in torchao 0.8.0+. @@ -131,7 +208,7 @@ print("bf16 model:", benchmark_fn(bf16_model.generate, **input_ids, max_new_toke ## Serialization -torchao implements [torch.Tensor subclasses](https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor) for maximum flexibility in supporting new quantized torch.Tensor formats. [Safetensors](https://huggingface.co/docs/safetensors/en/index) serialization and deserialization does not work with torchaco. +torchao implements [torch.Tensor subclasses](https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor) for maximum flexibility in supporting new quantized torch.Tensor formats. [Safetensors](https://huggingface.co/docs/safetensors/en/index) serialization and deserialization does not work with torchao. To avoid arbitrary user code execution, torchao sets `weights_only=True` in [torch.load](https://pytorch.org/docs/stable/generated/torch.load.html) to ensure only tensors are loaded. Any known user functions can be whitelisted with [add_safe_globals](https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals). diff --git a/src/transformers/quantizers/quantizer_torchao.py b/src/transformers/quantizers/quantizer_torchao.py index 0eb4eec997..1dddaf19d4 100644 --- a/src/transformers/quantizers/quantizer_torchao.py +++ b/src/transformers/quantizers/quantizer_torchao.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import importlib +import re import types from typing import TYPE_CHECKING, Optional, Union @@ -27,6 +28,7 @@ if TYPE_CHECKING: from typing import Any, Dict, List from ..utils import is_torch_available, is_torchao_available, logging +from ..utils.quantization_config import TorchAoConfig if is_torch_available(): @@ -36,6 +38,21 @@ if is_torch_available(): logger = logging.get_logger(__name__) +def fuzzy_match_size(config_name: str) -> Optional[str]: + """ + Extract the size digit from strings like "4weight", "8weight". + Returns the digit as an integer if found, otherwise None. + """ + config_name = config_name.lower() + + str_match = re.search(r"(\d)weight", config_name) + + if str_match: + return str_match.group(1) + + return None + + # Finds the parent of a node module named "name" def find_parent(model, name): module_tree = name.split(".")[:-1] @@ -121,10 +138,28 @@ class TorchAoHfQuantizer(HfQuantizer): torch_dtype = torch.float32 return torch_dtype - def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype": + def adjust_target_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": if version.parse(importlib.metadata.version("accelerate")) > version.parse("0.19.0"): from accelerate.utils import CustomDtype + # Import AOBaseConfig directly since we know we have the right version + if self.quantization_config._get_ao_version() >= version.Version("0.10.0"): + from torchao.core.config import AOBaseConfig + + quant_type = self.quantization_config.quant_type + if isinstance(quant_type, AOBaseConfig): + # Extract size digit using fuzzy match on the class name + config_name = quant_type.__class__.__name__ + size_digit = fuzzy_match_size(config_name) + + # Map the extracted digit to appropriate dtype + if size_digit == "4": + return CustomDtype.INT4 + else: + # Default to int8 + return torch.int8 + + # Original mapping for non-AOBaseConfig types map_to_target_dtype = { "int4_weight_only": CustomDtype.INT4, "int8_weight_only": torch.int8, @@ -194,14 +229,14 @@ class TorchAoHfQuantizer(HfQuantizer): from torchao.quantization import quantize_ module, tensor_name = get_module_from_name(model, param_name) - if self.pre_quantized: module._parameters[tensor_name] = torch.nn.Parameter(param_value.to(device=target_device)) if isinstance(module, nn.Linear): module.extra_repr = types.MethodType(_linear_extra_repr, module) else: + assert isinstance(self.quantization_config, TorchAoConfig) module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device) - quantize_(module, self.quantization_config.get_apply_tensor_subclass()) + quantize_(module, self.quantization_config.get_quantize_config()) def _process_model_after_weight_loading(self, model, **kwargs): """No process required for torchao quantized model""" @@ -216,7 +251,7 @@ class TorchAoHfQuantizer(HfQuantizer): return model return - def is_serializable(self, safe_serialization=None): + def is_serializable(self, safe_serialization=None) -> bool: if safe_serialization: logger.warning( "torchao quantized model does not support safe serialization, " @@ -237,7 +272,7 @@ class TorchAoHfQuantizer(HfQuantizer): return _is_torchao_serializable @property - def is_trainable(self): + def is_trainable(self) -> bool: supported_quant_types_for_training = [ "int8_weight_only", "int8_dynamic_activation_int8_weight", diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 4842970c77..4f2724f1c8 100755 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -95,6 +95,7 @@ GGUF_MIN_VERSION = "0.10.0" XLA_FSDPV2_MIN_VERSION = "2.2.0" HQQ_MIN_VERSION = "0.2.1" VPTQ_MIN_VERSION = "0.0.4" +TORCHAO_MIN_VERSION = "0.4.0" _accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True) @@ -191,7 +192,7 @@ _tf2onnx_available = _is_package_available("tf2onnx") _timm_available = _is_package_available("timm") _tokenizers_available = _is_package_available("tokenizers") _torchaudio_available = _is_package_available("torchaudio") -_torchao_available = _is_package_available("torchao") +_torchao_available, _torchao_version = _is_package_available("torchao", return_version=True) _torchdistx_available = _is_package_available("torchdistx") _torchvision_available, _torchvision_version = _is_package_available("torchvision", return_version=True) _mlx_available = _is_package_available("mlx") @@ -1277,8 +1278,8 @@ def is_torchaudio_available(): return _torchaudio_available -def is_torchao_available(): - return _torchao_available +def is_torchao_available(min_version: str = TORCHAO_MIN_VERSION): + return _torchao_available and version.parse(_torchao_version) >= version.parse(min_version) def is_speech_available(): diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index 851249b270..152572223f 100755 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -1455,11 +1455,18 @@ class HiggsConfig(QuantizationConfigMixin): @dataclass class TorchAoConfig(QuantizationConfigMixin): + quant_method: QuantizationMethod + quant_type: Union[str, "AOBaseConfig"] # noqa: F821 + modules_to_not_convert: Optional[List] + quant_type_kwargs: Dict[str, Any] + """This is a config class for torchao quantization/sparsity techniques. Args: - quant_type (`str`): - The type of quantization we want to use, currently supporting: `int4_weight_only`, `int8_weight_only`, `int8_dynamic_activation_int8_weight` and `autoquant`. + quant_type (`Union[str, AOBaseConfig]`): + The type of quantization we want to use. Can be either: + - A string: currently supporting: `int4_weight_only`, `int8_weight_only` and `int8_dynamic_activation_int8_weight`. + - An AOBaseConfig instance: for more advanced configuration options. modules_to_not_convert (`list`, *optional*, default to `None`): The list of modules to not quantize, useful for quantizing models that explicitly require to have some modules left in their original precision. @@ -1471,9 +1478,12 @@ class TorchAoConfig(QuantizationConfigMixin): Example: ```python - from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig + # AOBaseConfig-based configuration + config = Int4WeightOnlyConfig(group_size=32) + quantization_config = TorchAoConfig(config) + model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda", torch_dtype=torch.bfloat16, quantization_config=quantization_config) - # specific quantization method + # String-based configuration quantization_config = TorchAoConfig("int4_weight_only", group_size=32) # int4_weight_only quant is only working with *torch.bfloat16* dtype right now model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda", torch_dtype=torch.bfloat16, quantization_config=quantization_config) @@ -1496,105 +1506,152 @@ class TorchAoConfig(QuantizationConfigMixin): if hasattr(quantized_model, "finalize_autoquant"): print("finalizing autoquant") quantized_model.finalize_autoquant() + ``` """ - def __init__(self, quant_type: str, modules_to_not_convert: Optional[List] = None, **kwargs): + def __init__( + self, + quant_type: Union[str, "AOBaseConfig"], # noqa: F821 + modules_to_not_convert: Optional[List] = None, + **kwargs, + ): self.quant_method = QuantizationMethod.TORCHAO self.quant_type = quant_type self.modules_to_not_convert = modules_to_not_convert - # when we load from serailized config, "quant_type_kwargs" will be the key - if "quant_type_kwargs" in kwargs: - self.quant_type_kwargs = kwargs["quant_type_kwargs"] - else: - self.quant_type_kwargs = kwargs - + self.quant_type_kwargs = kwargs.get("quant_type_kwargs", kwargs) self.post_init() + @staticmethod + def _get_ao_version() -> version.Version: + """Centralized check for TorchAO availability and version requirements.""" + if not is_torchao_available(): + raise ValueError("TorchAoConfig requires torchao to be installed. Install with `pip install torchao`") + + return version.parse(importlib.metadata.version("torchao")) + def post_init(self): - r""" - Safety checker that arguments are correct - also replaces some NoneType arguments with their default values. - """ - if is_torchao_available(): - if not version.parse(importlib.metadata.version("torchao")) >= version.parse("0.7.0"): - raise ValueError("Requires torchao 0.7.0 version and above") + """Validate configuration and set defaults.""" + ao_version = self._get_ao_version() + + # Handle quant_type based on type and version + if isinstance(self.quant_type, str): + self._validate_string_quant_type() + elif ao_version >= version.parse("0.10.0"): + from torchao.quantization.quant_api import AOBaseConfig + + if not isinstance(self.quant_type, AOBaseConfig): + raise ValueError( + f"quant_type must be either a string or an AOBaseConfig instance, got {type(self.quant_type)}" + ) else: raise ValueError( - "TorchAoConfig requires torchao to be installed, please install with `pip install torchao`" + f"In torchao < 0.10.0, quant_type must be a string. Got {type(self.quant_type)}. " + f"Please upgrade to torchao >= 0.10.0 to use AOBaseConfig instances." ) - _STR_TO_METHOD = self._get_torchao_quant_type_to_method() - if self.quant_type not in _STR_TO_METHOD.keys(): + def _validate_string_quant_type(self): + """Validate string quant_type and its kwargs.""" + methods = self._get_torchao_quant_type_to_method() + + if self.quant_type not in methods: raise ValueError( - f"Requested quantization type: {self.quant_type} is not supported yet, please add support in TorchAoConfig and TorchAoHfQuantizer." + f"Unsupported string quantization type: {self.quant_type}. " + f"Supported types: {', '.join(methods.keys())}" ) - method = _STR_TO_METHOD[self.quant_type] + # Validate kwargs against method signature + method = methods[self.quant_type] sig = signature(method) - all_kwargs = [ + valid_kwargs = { param.name for param in sig.parameters.values() if param.kind in [Parameter.KEYWORD_ONLY, Parameter.POSITIONAL_OR_KEYWORD] - ] - for k in self.quant_type_kwargs: - if k not in all_kwargs: - raise ValueError( - f"Unexpected keyword arg: {k} for API: {method}, accepted keyword args are: {all_kwargs}" - ) + } + + invalid_kwargs = set(self.quant_type_kwargs) - valid_kwargs + if invalid_kwargs: + raise ValueError( + f"Unexpected keyword arg for {self.quant_type}: {', '.join(invalid_kwargs)}. " + f"Valid kwargs: {', '.join(valid_kwargs)}" + ) def _get_torchao_quant_type_to_method(self): - if is_torchao_available(): - from torchao.quantization import ( - autoquant, - int4_weight_only, - int8_dynamic_activation_int8_weight, - int8_weight_only, - ) + """Get mapping of quant_type strings to their corresponding methods.""" + from torchao.quantization import ( + autoquant, + int4_weight_only, + int8_dynamic_activation_int8_weight, + int8_weight_only, + ) - return { - "int4_weight_only": int4_weight_only, - "int8_weight_only": int8_weight_only, - "int8_dynamic_activation_int8_weight": int8_dynamic_activation_int8_weight, - "autoquant": autoquant, - } + return { + "int4_weight_only": int4_weight_only, + "int8_weight_only": int8_weight_only, + "int8_dynamic_activation_int8_weight": int8_dynamic_activation_int8_weight, + "autoquant": autoquant, + } + + def get_quantize_config(self): + """Create the appropriate quantization method based on configuration.""" + if isinstance(self.quant_type, str): + methods = self._get_torchao_quant_type_to_method() + quant_type_kwargs = self.quant_type_kwargs.copy() + if ( + not torch.cuda.is_available() + and is_torchao_available() + and self.quant_type == "int4_weight_only" + and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0") + ): + from torchao.dtypes import Int4CPULayout + + quant_type_kwargs["layout"] = Int4CPULayout() + + return methods[self.quant_type](**quant_type_kwargs) else: - raise ValueError( - "TorchAoConfig requires torchao to be installed, please install with `pip install torchao`" - ) + return self.quant_type - def get_apply_tensor_subclass(self): - _STR_TO_METHOD = self._get_torchao_quant_type_to_method() - quant_type_kwargs = self.quant_type_kwargs.copy() - if ( - not torch.cuda.is_available() - and is_torchao_available() - and self.quant_type == "int4_weight_only" - and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0") - ): - from torchao.dtypes import Int4CPULayout - - quant_type_kwargs["layout"] = Int4CPULayout() - return _STR_TO_METHOD[self.quant_type](**quant_type_kwargs) - - def __repr__(self): - config_dict = self.to_dict() - return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n" - - def to_dict(self) -> Dict[str, Any]: - """ - Serializes this instance to a Python dictionary, converting any `torchao.dtypes.Layout` - dataclasses to simple dicts. - - Returns: - `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. - """ + def to_dict(self): + """Convert configuration to a dictionary.""" d = super().to_dict() - if "quant_type_kwargs" in d and "layout" in d["quant_type_kwargs"]: - layout = d["quant_type_kwargs"]["layout"] - layout = dataclasses.asdict(layout) - d["quant_type_kwargs"]["layout"] = layout + + if isinstance(self.quant_type, str): + # Handle layout serialization if present + if "quant_type_kwargs" in d and "layout" in d["quant_type_kwargs"]: + d["quant_type_kwargs"]["layout"] = dataclasses.asdict(d["quant_type_kwargs"]["layout"]) + else: + # Handle AOBaseConfig serialization + from torchao.core.config import config_to_dict + + # For now we assume there is 1 config per Transfomer, however in the future + # We may want to support a config per fqn. + d["quant_type"] = {"default": config_to_dict(self.quant_type)} + return d + @classmethod + def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs): + """Create configuration from a dictionary.""" + ao_verison = cls._get_ao_version() + assert ao_verison >= version.parse( + "0.10.0" + ), "TorchAoConfig requires torchao >= 0.10.0 for construction from dict" + config_dict = config_dict.copy() + quant_type = config_dict.pop("quant_type") + # Check if we only have one key which is "default" + # In the future we may update this + assert ( + len(quant_type) == 1 and "default" in quant_type + ), "Expected only one key 'default' in quant_type dictionary" + quant_type = quant_type["default"] + + # Deserialize quant_type if needed + from torchao.core.config import config_from_dict + + quant_type = config_from_dict(quant_type) + + return cls(quant_type=quant_type, **config_dict) + @dataclass class BitNetConfig(QuantizationConfigMixin): diff --git a/tests/quantization/torchao_integration/test_torchao.py b/tests/quantization/torchao_integration/test_torchao.py index e6d723e678..037bf1506f 100644 --- a/tests/quantization/torchao_integration/test_torchao.py +++ b/tests/quantization/torchao_integration/test_torchao.py @@ -85,7 +85,7 @@ class TorchAoConfigTest(unittest.TestCase): Test kwargs validations in TorchAoConfig """ _ = TorchAoConfig("int4_weight_only") - with self.assertRaisesRegex(ValueError, "is not supported yet"): + with self.assertRaisesRegex(ValueError, "Unsupported string quantization type"): _ = TorchAoConfig("fp6") with self.assertRaisesRegex(ValueError, "Unexpected keyword arg"): @@ -408,5 +408,41 @@ class TorchAoSerializationW8GPUTest(TorchAoSerializationTest): device = "cuda:0" +@require_torch_gpu +@require_torchao_version_greater_or_equal("0.10.0") +class TorchAoSerializationFP8GPUTest(TorchAoSerializationTest): + ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)" + SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT + device = "cuda:0" + + def setUp(self): + if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 9: + raise unittest.SkipTest("CUDA compute capability 9.0 or higher required for FP8 tests") + + from torchao.quantization import Float8WeightOnlyConfig + + self.quant_scheme = Float8WeightOnlyConfig() + self.quant_scheme_kwargs = {} + super().setUp() + + +@require_torch_gpu +@require_torchao_version_greater_or_equal("0.10.0") +class TorchAoSerializationA8W4Test(TorchAoSerializationTest): + ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)" + SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT + device = "cuda:0" + + def setUp(self): + if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 9: + raise unittest.SkipTest("CUDA compute capability 9.0 or higher required for FP8 tests") + + from torchao.quantization import Int8DynamicActivationInt4WeightConfig + + self.quant_scheme = Int8DynamicActivationInt4WeightConfig() + self.quant_scheme_kwargs = {} + super().setUp() + + if __name__ == "__main__": unittest.main()