From 5f726f8b8e1badacec5a4da1df7150671100c828 Mon Sep 17 00:00:00 2001 From: Andrei Panferov Date: Fri, 14 Feb 2025 12:26:45 +0100 Subject: [PATCH] New HIGGS quantization interfaces, JIT kernel compilation support. (#36148) * new flute * new higgs working * small adjustments * progress and quallity * small updates * style --------- Co-authored-by: Andrey Panferov Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com> --- src/transformers/integrations/higgs.py | 24 +++-- .../quantizers/quantizer_higgs.py | 96 +++++++------------ src/transformers/utils/import_utils.py | 2 +- src/transformers/utils/quantization_config.py | 6 ++ tests/quantization/higgs/test_higgs.py | 6 +- 5 files changed, 54 insertions(+), 80 deletions(-) diff --git a/src/transformers/integrations/higgs.py b/src/transformers/integrations/higgs.py index 5a8f6537bb..3ba35eb4e4 100644 --- a/src/transformers/integrations/higgs.py +++ b/src/transformers/integrations/higgs.py @@ -28,15 +28,12 @@ if is_torch_available(): if is_flute_available(): - import flute.utils + from flute.integrations.higgs import prepare_data_transposed + from flute.tune import TuneMetaData, qgemm_v2 if is_hadamard_available(): from fast_hadamard_transform import hadamard_transform -if is_flute_available(): - import flute.utils - from flute.integrations.higgs import prepare_data_transposed - def pad_to_block(tensor, dims, had_block_size, value=0): pad_dims = [0 for _ in range(2 * len(tensor.shape))] @@ -464,14 +461,14 @@ def quantize_with_higgs(weight, bits: int = 4, p: int = 2, group_size: int = 256 # Quantize codes = torch.empty(weight.shape[:-1], device=device, dtype=torch.uint8) - for i in range(0, weight.shape[0], 64): - codes[i : i + 64] = torch.argmax(2 * weight[i : i + 64] @ grid.T - grid_norm_2, dim=-1).to(torch.uint8) + for i in range(0, weight.shape[0], 16): + codes[i : i + 16] = torch.argmax(2 * weight[i : i + 16] @ grid.T - grid_norm_2, dim=-1).to(torch.uint8) del weight codes = codes.reshape(codes.shape[0], -1) scales = scales / sqrt(hadamard_size) - weight, scales, tables, tables2 = prepare_data_transposed( + weight, scales, tables, tables2, tune_metadata = prepare_data_transposed( codes, torch.repeat_interleave(scales.to(dtype), hadamard_size // group_size, dim=1), grid.to(dtype), @@ -480,6 +477,7 @@ def quantize_with_higgs(weight, bits: int = 4, p: int = 2, group_size: int = 256 vector_size=p, dtype=dtype, device=device, + check_correctness=False, ) return { @@ -487,6 +485,7 @@ def quantize_with_higgs(weight, bits: int = 4, p: int = 2, group_size: int = 256 "scales": scales, "tables": tables, "tables2": tables2.view(dtype=torch.float16), + "tune_metadata": tune_metadata, } @@ -508,7 +507,6 @@ class HiggsLinear(torch.nn.Module): self.num_bits = num_bits self.group_size = group_size self.hadamard_size = hadamard_size - self.num_sms_packed = nn.Parameter(torch.tensor(-1, dtype=torch.int32, device=device), requires_grad=False) assert in_features % group_size == 0 assert num_bits in [2, 3, 4] @@ -531,6 +529,7 @@ class HiggsLinear(torch.nn.Module): self.register_parameter("bias", None) self.workspace = None # must be set externally to be reused among layers + self.tune_metadata: TuneMetaData = None # must be set externally because architecture dependent def forward(self, x): x = pad_to_block(x, [-1], self.hadamard_size) @@ -538,16 +537,15 @@ class HiggsLinear(torch.nn.Module): if self.workspace is None: raise Exception("Workspace must be set before calling forward") - return flute.qgemm_hadamard( + return qgemm_v2( x, self.weight, self.scales, self.tables, self.tables2.view(dtype=torch.float32), self.workspace, - self.num_bits, - self.group_size, - self.hadamard_size, + self.tune_metadata, + hadamard_size=self.hadamard_size, ) diff --git a/src/transformers/quantizers/quantizer_higgs.py b/src/transformers/quantizers/quantizer_higgs.py index f33e2f21e9..83c102f16c 100644 --- a/src/transformers/quantizers/quantizer_higgs.py +++ b/src/transformers/quantizers/quantizer_higgs.py @@ -13,6 +13,7 @@ # limitations under the License. from typing import TYPE_CHECKING, Any, Dict, List, Optional +from ..utils.logging import tqdm from .base import HfQuantizer from .quantizers_utils import get_module_from_name @@ -30,20 +31,6 @@ if is_torch_available(): logger = logging.get_logger(__name__) -def get_num_sms_from_device(device): - target_device_cc = torch.cuda.get_device_capability(device=device) - if target_device_cc == (8, 6): - return 84 - elif target_device_cc == (8, 0): - return 108 - elif target_device_cc == (8, 9): - return 128 - else: - raise NotImplementedError( - f"Device capability {target_device_cc} not supported for FLUTE (yet?) to verify your device capability check out https://developer.nvidia.com/cuda-gpus" - ) - - class HiggsHfQuantizer(HfQuantizer): """ Quantizer of the HIGGS method. Enables the loading of prequantized models and in-flight quantization of full-precision models. @@ -115,26 +102,24 @@ class HiggsHfQuantizer(HfQuantizer): self.quantization_config.group_size, self.quantization_config.hadamard_size, ) - del param_value - module, tensor_name = get_module_from_name(model, param_name) + module, _ = get_module_from_name(model, param_name) + module_name = ".".join(param_name.split(".")[:-1]) for key, value in flute_dict.items(): if key in module._parameters: module._parameters[key] = torch.nn.Parameter(value, requires_grad=False) elif key in module._buffers: module._buffers[key] = torch.nn.Buffer(value) + elif key == "tune_metadata": + module.tune_metadata = value + self.quantization_config.tune_metadata[module_name] = value.to_dict() else: raise ValueError(f"Unexpected key {key} in module {module}") if unexpected_keys is not None and param_name in unexpected_keys: unexpected_keys.remove(param_name) - module.num_sms_packed = torch.nn.Parameter( - torch.tensor(get_num_sms_from_device(target_device), device=target_device, dtype=torch.int32), - requires_grad=False, - ) - def _process_model_before_weight_loading( self, model: "PreTrainedModel", @@ -149,57 +134,42 @@ class HiggsHfQuantizer(HfQuantizer): model.config.quantization_config = self.quantization_config def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs): - import flute.utils + from flute.tune import TuneMetaData, maybe_tune_and_repack + from flute.utils import make_workspace_streamk from ..integrations import HiggsLinear flute_workspaces = {} - for name, module in model.named_modules(): - if isinstance(module, HiggsLinear): - # Every HiggsLinear needs a "workspace": a buffer for the unpacking operation. - # This buffer needs to be on the same device as the weights, but can be reused across modules otherwise. - if module.weight.device not in flute_workspaces: - flute_workspaces[module.weight.device] = flute.utils.make_workspace_streamk( - device=module.weight.device - ) - module.workspace = flute_workspaces[module.weight.device] + flute_modules = {name: module for name, module in model.named_modules() if isinstance(module, HiggsLinear)} + for name, module in tqdm(flute_modules.items(), desc="Repacking HIGGS modules", leave=False): + # Every HiggsLinear needs a "workspace": a buffer for the unpacking operation. + # This buffer needs to be on the same device as the weights, but can be reused across modules otherwise. + if module.weight.device not in flute_workspaces: + flute_workspaces[module.weight.device] = make_workspace_streamk(device=module.weight.device) + module.workspace = flute_workspaces[module.weight.device] - # FLUTE weights are packed in a way that is optimized for a specific number of SMs (GPU streaming multiprocessors). - # If the model is loaded on a different device than the one it was saved on, we need to repack the weights. - if module.num_sms_packed.item() != get_num_sms_from_device(module.weight.device): - new_device = module.weight.device - new_num_sms = get_num_sms_from_device(new_device) - module.weight.data = flute.utils.pack( - flute.utils.unpack( - weight=module.weight.data, - scales=module.scales.data, - workspace=module.workspace, - num_bits=module.num_bits, - group_size=module.group_size, - num_sms_packed=module.num_sms_packed.item(), - ).T.contiguous(), - module.num_bits, - module.group_size, - ) - module.num_sms_packed = torch.nn.Parameter( - torch.tensor(new_num_sms, device=new_device, dtype=torch.int32), - requires_grad=False, - ) + # FLUTE weights are packed in a way that is optimized for a specific number of SMs (GPU streaming multiprocessors). + # If the model is loaded on a different device than the one it was saved on, we need to repack the weights. + module.tune_metadata = TuneMetaData.from_dict(self.quantization_config.tune_metadata[name]) + module.weight.data, module.tune_metadata = maybe_tune_and_repack( + weight=module.weight.data, + scales=module.scales.data, + metadata=module.tune_metadata, + ) + self.quantization_config.tune_metadata[name] = module.tune_metadata.to_dict() def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]: from ..integrations import HiggsLinear - not_missing_keys = [] - for name, module in model.named_modules(): - if isinstance(module, HiggsLinear): - for missing in missing_keys: - if ( - (name in missing or name in f"{prefix}.{missing}") - and not missing.endswith(".weight") - and not missing.endswith(".bias") - ): - not_missing_keys.append(missing) - return [k for k in missing_keys if k not in not_missing_keys] + higgs_names = {name for name, module in model.named_modules() if isinstance(module, HiggsLinear)} + + def should_update(key: str) -> bool: + if key.endswith(".weight") or key.endswith(".bias"): + return False + full_key = f"{prefix}.{key}" + return any(name in key or name in full_key for name in higgs_names) + + return [key for key in missing_keys if not should_update(key)] @property def is_trainable(self, model: Optional["PreTrainedModel"] = None): diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index bd95b6f282..aa7be764c5 100755 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -639,7 +639,7 @@ def is_flax_available(): def is_flute_available(): try: - return importlib.util.find_spec("flute") is not None and importlib.metadata.version("flute-kernel") >= "0.3.0" + return importlib.util.find_spec("flute") is not None and importlib.metadata.version("flute-kernel") >= "0.4.1" except importlib.metadata.PackageNotFoundError: return False diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index 11415e895d..ec8a5ef70d 100755 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -1404,6 +1404,8 @@ class HiggsConfig(QuantizationConfigMixin): Hadamard size for the HIGGS method. Default is 512. Input dimension of matrices is padded to this value. Decreasing this below 512 will reduce the quality of the quantization. group_size (int, *optional*, defaults to 256): Group size for the HIGGS method. Can be 64, 128 or 256. Decreasing it barely affects the performance. Default is 256. Must be a divisor of hadamard_size. + tune_metadata ('dict', *optional*, defaults to {}): + Module-wise metadata (gemm block shapes, GPU metadata, etc.) for saving the kernel tuning results. Default is an empty dictionary. Is set automatically during tuning. """ def __init__( @@ -1413,16 +1415,20 @@ class HiggsConfig(QuantizationConfigMixin): modules_to_not_convert: Optional[List[str]] = None, hadamard_size: int = 512, group_size: int = 256, + tune_metadata: Optional[Dict[str, Any]] = None, **kwargs, ): if modules_to_not_convert is None: modules_to_not_convert = ["lm_head"] + if tune_metadata is None: + tune_metadata = {} self.quant_method = QuantizationMethod.HIGGS self.bits = bits self.p = p self.modules_to_not_convert = modules_to_not_convert self.hadamard_size = hadamard_size self.group_size = group_size + self.tune_metadata = tune_metadata self.post_init() diff --git a/tests/quantization/higgs/test_higgs.py b/tests/quantization/higgs/test_higgs.py index 26ee6bc056..5c17ed63aa 100644 --- a/tests/quantization/higgs/test_higgs.py +++ b/tests/quantization/higgs/test_higgs.py @@ -65,12 +65,12 @@ class HiggsConfigTest(unittest.TestCase): @require_accelerate # @require_read_token class HiggsTest(unittest.TestCase): - model_name = "meta-llama/Meta-Llama-3.1-8B" + model_name = "unsloth/Llama-3.2-1B" - input_text = "A quick brown fox jumps over the" + input_text = "Font test: A quick brown fox jumps over the" max_new_tokens = 2 - EXPECTED_OUTPUT = "A quick brown fox jumps over the lazy dog" + EXPECTED_OUTPUT = "Font test: A quick brown fox jumps over the lazy dog" device_map = "cuda"