New HIGGS quantization interfaces, JIT kernel compilation support. (#36148)

* new flute

* new higgs working

* small adjustments

* progress and quallity

* small updates

* style

---------

Co-authored-by: Andrey Panferov <panferov.andrey3@wb.ru>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
This commit is contained in:
Andrei Panferov
2025-02-14 12:26:45 +01:00
committed by GitHub
parent 15ec971b8e
commit 5f726f8b8e
5 changed files with 54 additions and 80 deletions

View File

@@ -28,15 +28,12 @@ if is_torch_available():
if is_flute_available(): if is_flute_available():
import flute.utils from flute.integrations.higgs import prepare_data_transposed
from flute.tune import TuneMetaData, qgemm_v2
if is_hadamard_available(): if is_hadamard_available():
from fast_hadamard_transform import hadamard_transform from fast_hadamard_transform import hadamard_transform
if is_flute_available():
import flute.utils
from flute.integrations.higgs import prepare_data_transposed
def pad_to_block(tensor, dims, had_block_size, value=0): def pad_to_block(tensor, dims, had_block_size, value=0):
pad_dims = [0 for _ in range(2 * len(tensor.shape))] pad_dims = [0 for _ in range(2 * len(tensor.shape))]
@@ -464,14 +461,14 @@ def quantize_with_higgs(weight, bits: int = 4, p: int = 2, group_size: int = 256
# Quantize # Quantize
codes = torch.empty(weight.shape[:-1], device=device, dtype=torch.uint8) codes = torch.empty(weight.shape[:-1], device=device, dtype=torch.uint8)
for i in range(0, weight.shape[0], 64): for i in range(0, weight.shape[0], 16):
codes[i : i + 64] = torch.argmax(2 * weight[i : i + 64] @ grid.T - grid_norm_2, dim=-1).to(torch.uint8) codes[i : i + 16] = torch.argmax(2 * weight[i : i + 16] @ grid.T - grid_norm_2, dim=-1).to(torch.uint8)
del weight del weight
codes = codes.reshape(codes.shape[0], -1) codes = codes.reshape(codes.shape[0], -1)
scales = scales / sqrt(hadamard_size) scales = scales / sqrt(hadamard_size)
weight, scales, tables, tables2 = prepare_data_transposed( weight, scales, tables, tables2, tune_metadata = prepare_data_transposed(
codes, codes,
torch.repeat_interleave(scales.to(dtype), hadamard_size // group_size, dim=1), torch.repeat_interleave(scales.to(dtype), hadamard_size // group_size, dim=1),
grid.to(dtype), grid.to(dtype),
@@ -480,6 +477,7 @@ def quantize_with_higgs(weight, bits: int = 4, p: int = 2, group_size: int = 256
vector_size=p, vector_size=p,
dtype=dtype, dtype=dtype,
device=device, device=device,
check_correctness=False,
) )
return { return {
@@ -487,6 +485,7 @@ def quantize_with_higgs(weight, bits: int = 4, p: int = 2, group_size: int = 256
"scales": scales, "scales": scales,
"tables": tables, "tables": tables,
"tables2": tables2.view(dtype=torch.float16), "tables2": tables2.view(dtype=torch.float16),
"tune_metadata": tune_metadata,
} }
@@ -508,7 +507,6 @@ class HiggsLinear(torch.nn.Module):
self.num_bits = num_bits self.num_bits = num_bits
self.group_size = group_size self.group_size = group_size
self.hadamard_size = hadamard_size self.hadamard_size = hadamard_size
self.num_sms_packed = nn.Parameter(torch.tensor(-1, dtype=torch.int32, device=device), requires_grad=False)
assert in_features % group_size == 0 assert in_features % group_size == 0
assert num_bits in [2, 3, 4] assert num_bits in [2, 3, 4]
@@ -531,6 +529,7 @@ class HiggsLinear(torch.nn.Module):
self.register_parameter("bias", None) self.register_parameter("bias", None)
self.workspace = None # must be set externally to be reused among layers self.workspace = None # must be set externally to be reused among layers
self.tune_metadata: TuneMetaData = None # must be set externally because architecture dependent
def forward(self, x): def forward(self, x):
x = pad_to_block(x, [-1], self.hadamard_size) x = pad_to_block(x, [-1], self.hadamard_size)
@@ -538,16 +537,15 @@ class HiggsLinear(torch.nn.Module):
if self.workspace is None: if self.workspace is None:
raise Exception("Workspace must be set before calling forward") raise Exception("Workspace must be set before calling forward")
return flute.qgemm_hadamard( return qgemm_v2(
x, x,
self.weight, self.weight,
self.scales, self.scales,
self.tables, self.tables,
self.tables2.view(dtype=torch.float32), self.tables2.view(dtype=torch.float32),
self.workspace, self.workspace,
self.num_bits, self.tune_metadata,
self.group_size, hadamard_size=self.hadamard_size,
self.hadamard_size,
) )

View File

@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING, Any, Dict, List, Optional from typing import TYPE_CHECKING, Any, Dict, List, Optional
from ..utils.logging import tqdm
from .base import HfQuantizer from .base import HfQuantizer
from .quantizers_utils import get_module_from_name from .quantizers_utils import get_module_from_name
@@ -30,20 +31,6 @@ if is_torch_available():
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
def get_num_sms_from_device(device):
target_device_cc = torch.cuda.get_device_capability(device=device)
if target_device_cc == (8, 6):
return 84
elif target_device_cc == (8, 0):
return 108
elif target_device_cc == (8, 9):
return 128
else:
raise NotImplementedError(
f"Device capability {target_device_cc} not supported for FLUTE (yet?) to verify your device capability check out https://developer.nvidia.com/cuda-gpus"
)
class HiggsHfQuantizer(HfQuantizer): class HiggsHfQuantizer(HfQuantizer):
""" """
Quantizer of the HIGGS method. Enables the loading of prequantized models and in-flight quantization of full-precision models. Quantizer of the HIGGS method. Enables the loading of prequantized models and in-flight quantization of full-precision models.
@@ -115,26 +102,24 @@ class HiggsHfQuantizer(HfQuantizer):
self.quantization_config.group_size, self.quantization_config.group_size,
self.quantization_config.hadamard_size, self.quantization_config.hadamard_size,
) )
del param_value del param_value
module, tensor_name = get_module_from_name(model, param_name) module, _ = get_module_from_name(model, param_name)
module_name = ".".join(param_name.split(".")[:-1])
for key, value in flute_dict.items(): for key, value in flute_dict.items():
if key in module._parameters: if key in module._parameters:
module._parameters[key] = torch.nn.Parameter(value, requires_grad=False) module._parameters[key] = torch.nn.Parameter(value, requires_grad=False)
elif key in module._buffers: elif key in module._buffers:
module._buffers[key] = torch.nn.Buffer(value) module._buffers[key] = torch.nn.Buffer(value)
elif key == "tune_metadata":
module.tune_metadata = value
self.quantization_config.tune_metadata[module_name] = value.to_dict()
else: else:
raise ValueError(f"Unexpected key {key} in module {module}") raise ValueError(f"Unexpected key {key} in module {module}")
if unexpected_keys is not None and param_name in unexpected_keys: if unexpected_keys is not None and param_name in unexpected_keys:
unexpected_keys.remove(param_name) unexpected_keys.remove(param_name)
module.num_sms_packed = torch.nn.Parameter(
torch.tensor(get_num_sms_from_device(target_device), device=target_device, dtype=torch.int32),
requires_grad=False,
)
def _process_model_before_weight_loading( def _process_model_before_weight_loading(
self, self,
model: "PreTrainedModel", model: "PreTrainedModel",
@@ -149,57 +134,42 @@ class HiggsHfQuantizer(HfQuantizer):
model.config.quantization_config = self.quantization_config model.config.quantization_config = self.quantization_config
def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs): def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
import flute.utils from flute.tune import TuneMetaData, maybe_tune_and_repack
from flute.utils import make_workspace_streamk
from ..integrations import HiggsLinear from ..integrations import HiggsLinear
flute_workspaces = {} flute_workspaces = {}
for name, module in model.named_modules(): flute_modules = {name: module for name, module in model.named_modules() if isinstance(module, HiggsLinear)}
if isinstance(module, HiggsLinear): for name, module in tqdm(flute_modules.items(), desc="Repacking HIGGS modules", leave=False):
# Every HiggsLinear needs a "workspace": a buffer for the unpacking operation. # Every HiggsLinear needs a "workspace": a buffer for the unpacking operation.
# This buffer needs to be on the same device as the weights, but can be reused across modules otherwise. # This buffer needs to be on the same device as the weights, but can be reused across modules otherwise.
if module.weight.device not in flute_workspaces: if module.weight.device not in flute_workspaces:
flute_workspaces[module.weight.device] = flute.utils.make_workspace_streamk( flute_workspaces[module.weight.device] = make_workspace_streamk(device=module.weight.device)
device=module.weight.device module.workspace = flute_workspaces[module.weight.device]
)
module.workspace = flute_workspaces[module.weight.device]
# FLUTE weights are packed in a way that is optimized for a specific number of SMs (GPU streaming multiprocessors). # FLUTE weights are packed in a way that is optimized for a specific number of SMs (GPU streaming multiprocessors).
# If the model is loaded on a different device than the one it was saved on, we need to repack the weights. # If the model is loaded on a different device than the one it was saved on, we need to repack the weights.
if module.num_sms_packed.item() != get_num_sms_from_device(module.weight.device): module.tune_metadata = TuneMetaData.from_dict(self.quantization_config.tune_metadata[name])
new_device = module.weight.device module.weight.data, module.tune_metadata = maybe_tune_and_repack(
new_num_sms = get_num_sms_from_device(new_device) weight=module.weight.data,
module.weight.data = flute.utils.pack( scales=module.scales.data,
flute.utils.unpack( metadata=module.tune_metadata,
weight=module.weight.data, )
scales=module.scales.data, self.quantization_config.tune_metadata[name] = module.tune_metadata.to_dict()
workspace=module.workspace,
num_bits=module.num_bits,
group_size=module.group_size,
num_sms_packed=module.num_sms_packed.item(),
).T.contiguous(),
module.num_bits,
module.group_size,
)
module.num_sms_packed = torch.nn.Parameter(
torch.tensor(new_num_sms, device=new_device, dtype=torch.int32),
requires_grad=False,
)
def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]: def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
from ..integrations import HiggsLinear from ..integrations import HiggsLinear
not_missing_keys = [] higgs_names = {name for name, module in model.named_modules() if isinstance(module, HiggsLinear)}
for name, module in model.named_modules():
if isinstance(module, HiggsLinear): def should_update(key: str) -> bool:
for missing in missing_keys: if key.endswith(".weight") or key.endswith(".bias"):
if ( return False
(name in missing or name in f"{prefix}.{missing}") full_key = f"{prefix}.{key}"
and not missing.endswith(".weight") return any(name in key or name in full_key for name in higgs_names)
and not missing.endswith(".bias")
): return [key for key in missing_keys if not should_update(key)]
not_missing_keys.append(missing)
return [k for k in missing_keys if k not in not_missing_keys]
@property @property
def is_trainable(self, model: Optional["PreTrainedModel"] = None): def is_trainable(self, model: Optional["PreTrainedModel"] = None):

View File

@@ -639,7 +639,7 @@ def is_flax_available():
def is_flute_available(): def is_flute_available():
try: try:
return importlib.util.find_spec("flute") is not None and importlib.metadata.version("flute-kernel") >= "0.3.0" return importlib.util.find_spec("flute") is not None and importlib.metadata.version("flute-kernel") >= "0.4.1"
except importlib.metadata.PackageNotFoundError: except importlib.metadata.PackageNotFoundError:
return False return False

View File

@@ -1404,6 +1404,8 @@ class HiggsConfig(QuantizationConfigMixin):
Hadamard size for the HIGGS method. Default is 512. Input dimension of matrices is padded to this value. Decreasing this below 512 will reduce the quality of the quantization. Hadamard size for the HIGGS method. Default is 512. Input dimension of matrices is padded to this value. Decreasing this below 512 will reduce the quality of the quantization.
group_size (int, *optional*, defaults to 256): group_size (int, *optional*, defaults to 256):
Group size for the HIGGS method. Can be 64, 128 or 256. Decreasing it barely affects the performance. Default is 256. Must be a divisor of hadamard_size. Group size for the HIGGS method. Can be 64, 128 or 256. Decreasing it barely affects the performance. Default is 256. Must be a divisor of hadamard_size.
tune_metadata ('dict', *optional*, defaults to {}):
Module-wise metadata (gemm block shapes, GPU metadata, etc.) for saving the kernel tuning results. Default is an empty dictionary. Is set automatically during tuning.
""" """
def __init__( def __init__(
@@ -1413,16 +1415,20 @@ class HiggsConfig(QuantizationConfigMixin):
modules_to_not_convert: Optional[List[str]] = None, modules_to_not_convert: Optional[List[str]] = None,
hadamard_size: int = 512, hadamard_size: int = 512,
group_size: int = 256, group_size: int = 256,
tune_metadata: Optional[Dict[str, Any]] = None,
**kwargs, **kwargs,
): ):
if modules_to_not_convert is None: if modules_to_not_convert is None:
modules_to_not_convert = ["lm_head"] modules_to_not_convert = ["lm_head"]
if tune_metadata is None:
tune_metadata = {}
self.quant_method = QuantizationMethod.HIGGS self.quant_method = QuantizationMethod.HIGGS
self.bits = bits self.bits = bits
self.p = p self.p = p
self.modules_to_not_convert = modules_to_not_convert self.modules_to_not_convert = modules_to_not_convert
self.hadamard_size = hadamard_size self.hadamard_size = hadamard_size
self.group_size = group_size self.group_size = group_size
self.tune_metadata = tune_metadata
self.post_init() self.post_init()

View File

@@ -65,12 +65,12 @@ class HiggsConfigTest(unittest.TestCase):
@require_accelerate @require_accelerate
# @require_read_token # @require_read_token
class HiggsTest(unittest.TestCase): class HiggsTest(unittest.TestCase):
model_name = "meta-llama/Meta-Llama-3.1-8B" model_name = "unsloth/Llama-3.2-1B"
input_text = "A quick brown fox jumps over the" input_text = "Font test: A quick brown fox jumps over the"
max_new_tokens = 2 max_new_tokens = 2
EXPECTED_OUTPUT = "A quick brown fox jumps over the lazy dog" EXPECTED_OUTPUT = "Font test: A quick brown fox jumps over the lazy dog"
device_map = "cuda" device_map = "cuda"