New HIGGS quantization interfaces, JIT kernel compilation support. (#36148)

* new flute

* new higgs working

* small adjustments

* progress and quallity

* small updates

* style

---------

Co-authored-by: Andrey Panferov <panferov.andrey3@wb.ru>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
This commit is contained in:
Andrei Panferov
2025-02-14 12:26:45 +01:00
committed by GitHub
parent 15ec971b8e
commit 5f726f8b8e
5 changed files with 54 additions and 80 deletions

View File

@@ -28,15 +28,12 @@ if is_torch_available():
if is_flute_available():
import flute.utils
from flute.integrations.higgs import prepare_data_transposed
from flute.tune import TuneMetaData, qgemm_v2
if is_hadamard_available():
from fast_hadamard_transform import hadamard_transform
if is_flute_available():
import flute.utils
from flute.integrations.higgs import prepare_data_transposed
def pad_to_block(tensor, dims, had_block_size, value=0):
pad_dims = [0 for _ in range(2 * len(tensor.shape))]
@@ -464,14 +461,14 @@ def quantize_with_higgs(weight, bits: int = 4, p: int = 2, group_size: int = 256
# Quantize
codes = torch.empty(weight.shape[:-1], device=device, dtype=torch.uint8)
for i in range(0, weight.shape[0], 64):
codes[i : i + 64] = torch.argmax(2 * weight[i : i + 64] @ grid.T - grid_norm_2, dim=-1).to(torch.uint8)
for i in range(0, weight.shape[0], 16):
codes[i : i + 16] = torch.argmax(2 * weight[i : i + 16] @ grid.T - grid_norm_2, dim=-1).to(torch.uint8)
del weight
codes = codes.reshape(codes.shape[0], -1)
scales = scales / sqrt(hadamard_size)
weight, scales, tables, tables2 = prepare_data_transposed(
weight, scales, tables, tables2, tune_metadata = prepare_data_transposed(
codes,
torch.repeat_interleave(scales.to(dtype), hadamard_size // group_size, dim=1),
grid.to(dtype),
@@ -480,6 +477,7 @@ def quantize_with_higgs(weight, bits: int = 4, p: int = 2, group_size: int = 256
vector_size=p,
dtype=dtype,
device=device,
check_correctness=False,
)
return {
@@ -487,6 +485,7 @@ def quantize_with_higgs(weight, bits: int = 4, p: int = 2, group_size: int = 256
"scales": scales,
"tables": tables,
"tables2": tables2.view(dtype=torch.float16),
"tune_metadata": tune_metadata,
}
@@ -508,7 +507,6 @@ class HiggsLinear(torch.nn.Module):
self.num_bits = num_bits
self.group_size = group_size
self.hadamard_size = hadamard_size
self.num_sms_packed = nn.Parameter(torch.tensor(-1, dtype=torch.int32, device=device), requires_grad=False)
assert in_features % group_size == 0
assert num_bits in [2, 3, 4]
@@ -531,6 +529,7 @@ class HiggsLinear(torch.nn.Module):
self.register_parameter("bias", None)
self.workspace = None # must be set externally to be reused among layers
self.tune_metadata: TuneMetaData = None # must be set externally because architecture dependent
def forward(self, x):
x = pad_to_block(x, [-1], self.hadamard_size)
@@ -538,16 +537,15 @@ class HiggsLinear(torch.nn.Module):
if self.workspace is None:
raise Exception("Workspace must be set before calling forward")
return flute.qgemm_hadamard(
return qgemm_v2(
x,
self.weight,
self.scales,
self.tables,
self.tables2.view(dtype=torch.float32),
self.workspace,
self.num_bits,
self.group_size,
self.hadamard_size,
self.tune_metadata,
hadamard_size=self.hadamard_size,
)

View File

@@ -13,6 +13,7 @@
# limitations under the License.
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from ..utils.logging import tqdm
from .base import HfQuantizer
from .quantizers_utils import get_module_from_name
@@ -30,20 +31,6 @@ if is_torch_available():
logger = logging.get_logger(__name__)
def get_num_sms_from_device(device):
target_device_cc = torch.cuda.get_device_capability(device=device)
if target_device_cc == (8, 6):
return 84
elif target_device_cc == (8, 0):
return 108
elif target_device_cc == (8, 9):
return 128
else:
raise NotImplementedError(
f"Device capability {target_device_cc} not supported for FLUTE (yet?) to verify your device capability check out https://developer.nvidia.com/cuda-gpus"
)
class HiggsHfQuantizer(HfQuantizer):
"""
Quantizer of the HIGGS method. Enables the loading of prequantized models and in-flight quantization of full-precision models.
@@ -115,26 +102,24 @@ class HiggsHfQuantizer(HfQuantizer):
self.quantization_config.group_size,
self.quantization_config.hadamard_size,
)
del param_value
module, tensor_name = get_module_from_name(model, param_name)
module, _ = get_module_from_name(model, param_name)
module_name = ".".join(param_name.split(".")[:-1])
for key, value in flute_dict.items():
if key in module._parameters:
module._parameters[key] = torch.nn.Parameter(value, requires_grad=False)
elif key in module._buffers:
module._buffers[key] = torch.nn.Buffer(value)
elif key == "tune_metadata":
module.tune_metadata = value
self.quantization_config.tune_metadata[module_name] = value.to_dict()
else:
raise ValueError(f"Unexpected key {key} in module {module}")
if unexpected_keys is not None and param_name in unexpected_keys:
unexpected_keys.remove(param_name)
module.num_sms_packed = torch.nn.Parameter(
torch.tensor(get_num_sms_from_device(target_device), device=target_device, dtype=torch.int32),
requires_grad=False,
)
def _process_model_before_weight_loading(
self,
model: "PreTrainedModel",
@@ -149,57 +134,42 @@ class HiggsHfQuantizer(HfQuantizer):
model.config.quantization_config = self.quantization_config
def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
import flute.utils
from flute.tune import TuneMetaData, maybe_tune_and_repack
from flute.utils import make_workspace_streamk
from ..integrations import HiggsLinear
flute_workspaces = {}
for name, module in model.named_modules():
if isinstance(module, HiggsLinear):
# Every HiggsLinear needs a "workspace": a buffer for the unpacking operation.
# This buffer needs to be on the same device as the weights, but can be reused across modules otherwise.
if module.weight.device not in flute_workspaces:
flute_workspaces[module.weight.device] = flute.utils.make_workspace_streamk(
device=module.weight.device
)
module.workspace = flute_workspaces[module.weight.device]
flute_modules = {name: module for name, module in model.named_modules() if isinstance(module, HiggsLinear)}
for name, module in tqdm(flute_modules.items(), desc="Repacking HIGGS modules", leave=False):
# Every HiggsLinear needs a "workspace": a buffer for the unpacking operation.
# This buffer needs to be on the same device as the weights, but can be reused across modules otherwise.
if module.weight.device not in flute_workspaces:
flute_workspaces[module.weight.device] = make_workspace_streamk(device=module.weight.device)
module.workspace = flute_workspaces[module.weight.device]
# FLUTE weights are packed in a way that is optimized for a specific number of SMs (GPU streaming multiprocessors).
# If the model is loaded on a different device than the one it was saved on, we need to repack the weights.
if module.num_sms_packed.item() != get_num_sms_from_device(module.weight.device):
new_device = module.weight.device
new_num_sms = get_num_sms_from_device(new_device)
module.weight.data = flute.utils.pack(
flute.utils.unpack(
weight=module.weight.data,
scales=module.scales.data,
workspace=module.workspace,
num_bits=module.num_bits,
group_size=module.group_size,
num_sms_packed=module.num_sms_packed.item(),
).T.contiguous(),
module.num_bits,
module.group_size,
)
module.num_sms_packed = torch.nn.Parameter(
torch.tensor(new_num_sms, device=new_device, dtype=torch.int32),
requires_grad=False,
)
# FLUTE weights are packed in a way that is optimized for a specific number of SMs (GPU streaming multiprocessors).
# If the model is loaded on a different device than the one it was saved on, we need to repack the weights.
module.tune_metadata = TuneMetaData.from_dict(self.quantization_config.tune_metadata[name])
module.weight.data, module.tune_metadata = maybe_tune_and_repack(
weight=module.weight.data,
scales=module.scales.data,
metadata=module.tune_metadata,
)
self.quantization_config.tune_metadata[name] = module.tune_metadata.to_dict()
def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
from ..integrations import HiggsLinear
not_missing_keys = []
for name, module in model.named_modules():
if isinstance(module, HiggsLinear):
for missing in missing_keys:
if (
(name in missing or name in f"{prefix}.{missing}")
and not missing.endswith(".weight")
and not missing.endswith(".bias")
):
not_missing_keys.append(missing)
return [k for k in missing_keys if k not in not_missing_keys]
higgs_names = {name for name, module in model.named_modules() if isinstance(module, HiggsLinear)}
def should_update(key: str) -> bool:
if key.endswith(".weight") or key.endswith(".bias"):
return False
full_key = f"{prefix}.{key}"
return any(name in key or name in full_key for name in higgs_names)
return [key for key in missing_keys if not should_update(key)]
@property
def is_trainable(self, model: Optional["PreTrainedModel"] = None):

View File

@@ -639,7 +639,7 @@ def is_flax_available():
def is_flute_available():
try:
return importlib.util.find_spec("flute") is not None and importlib.metadata.version("flute-kernel") >= "0.3.0"
return importlib.util.find_spec("flute") is not None and importlib.metadata.version("flute-kernel") >= "0.4.1"
except importlib.metadata.PackageNotFoundError:
return False

View File

@@ -1404,6 +1404,8 @@ class HiggsConfig(QuantizationConfigMixin):
Hadamard size for the HIGGS method. Default is 512. Input dimension of matrices is padded to this value. Decreasing this below 512 will reduce the quality of the quantization.
group_size (int, *optional*, defaults to 256):
Group size for the HIGGS method. Can be 64, 128 or 256. Decreasing it barely affects the performance. Default is 256. Must be a divisor of hadamard_size.
tune_metadata ('dict', *optional*, defaults to {}):
Module-wise metadata (gemm block shapes, GPU metadata, etc.) for saving the kernel tuning results. Default is an empty dictionary. Is set automatically during tuning.
"""
def __init__(
@@ -1413,16 +1415,20 @@ class HiggsConfig(QuantizationConfigMixin):
modules_to_not_convert: Optional[List[str]] = None,
hadamard_size: int = 512,
group_size: int = 256,
tune_metadata: Optional[Dict[str, Any]] = None,
**kwargs,
):
if modules_to_not_convert is None:
modules_to_not_convert = ["lm_head"]
if tune_metadata is None:
tune_metadata = {}
self.quant_method = QuantizationMethod.HIGGS
self.bits = bits
self.p = p
self.modules_to_not_convert = modules_to_not_convert
self.hadamard_size = hadamard_size
self.group_size = group_size
self.tune_metadata = tune_metadata
self.post_init()

View File

@@ -65,12 +65,12 @@ class HiggsConfigTest(unittest.TestCase):
@require_accelerate
# @require_read_token
class HiggsTest(unittest.TestCase):
model_name = "meta-llama/Meta-Llama-3.1-8B"
model_name = "unsloth/Llama-3.2-1B"
input_text = "A quick brown fox jumps over the"
input_text = "Font test: A quick brown fox jumps over the"
max_new_tokens = 2
EXPECTED_OUTPUT = "A quick brown fox jumps over the lazy dog"
EXPECTED_OUTPUT = "Font test: A quick brown fox jumps over the lazy dog"
device_map = "cuda"