New HIGGS quantization interfaces, JIT kernel compilation support. (#36148)
* new flute * new higgs working * small adjustments * progress and quality * small updates * style --------- Co-authored-by: Andrey Panferov <panferov.andrey3@wb.ru> Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
This commit is contained in:
@@ -28,15 +28,12 @@ if is_torch_available():
|
|||||||
|
|
||||||
|
|
||||||
if is_flute_available():
|
if is_flute_available():
|
||||||
import flute.utils
|
from flute.integrations.higgs import prepare_data_transposed
|
||||||
|
from flute.tune import TuneMetaData, qgemm_v2
|
||||||
|
|
||||||
if is_hadamard_available():
|
if is_hadamard_available():
|
||||||
from fast_hadamard_transform import hadamard_transform
|
from fast_hadamard_transform import hadamard_transform
|
||||||
|
|
||||||
if is_flute_available():
|
|
||||||
import flute.utils
|
|
||||||
from flute.integrations.higgs import prepare_data_transposed
|
|
||||||
|
|
||||||
|
|
||||||
def pad_to_block(tensor, dims, had_block_size, value=0):
|
def pad_to_block(tensor, dims, had_block_size, value=0):
|
||||||
pad_dims = [0 for _ in range(2 * len(tensor.shape))]
|
pad_dims = [0 for _ in range(2 * len(tensor.shape))]
|
||||||
@@ -464,14 +461,14 @@ def quantize_with_higgs(weight, bits: int = 4, p: int = 2, group_size: int = 256
|
|||||||
|
|
||||||
# Quantize
|
# Quantize
|
||||||
codes = torch.empty(weight.shape[:-1], device=device, dtype=torch.uint8)
|
codes = torch.empty(weight.shape[:-1], device=device, dtype=torch.uint8)
|
||||||
for i in range(0, weight.shape[0], 64):
|
for i in range(0, weight.shape[0], 16):
|
||||||
codes[i : i + 64] = torch.argmax(2 * weight[i : i + 64] @ grid.T - grid_norm_2, dim=-1).to(torch.uint8)
|
codes[i : i + 16] = torch.argmax(2 * weight[i : i + 16] @ grid.T - grid_norm_2, dim=-1).to(torch.uint8)
|
||||||
del weight
|
del weight
|
||||||
|
|
||||||
codes = codes.reshape(codes.shape[0], -1)
|
codes = codes.reshape(codes.shape[0], -1)
|
||||||
scales = scales / sqrt(hadamard_size)
|
scales = scales / sqrt(hadamard_size)
|
||||||
|
|
||||||
weight, scales, tables, tables2 = prepare_data_transposed(
|
weight, scales, tables, tables2, tune_metadata = prepare_data_transposed(
|
||||||
codes,
|
codes,
|
||||||
torch.repeat_interleave(scales.to(dtype), hadamard_size // group_size, dim=1),
|
torch.repeat_interleave(scales.to(dtype), hadamard_size // group_size, dim=1),
|
||||||
grid.to(dtype),
|
grid.to(dtype),
|
||||||
@@ -480,6 +477,7 @@ def quantize_with_higgs(weight, bits: int = 4, p: int = 2, group_size: int = 256
|
|||||||
vector_size=p,
|
vector_size=p,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
|
check_correctness=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@@ -487,6 +485,7 @@ def quantize_with_higgs(weight, bits: int = 4, p: int = 2, group_size: int = 256
|
|||||||
"scales": scales,
|
"scales": scales,
|
||||||
"tables": tables,
|
"tables": tables,
|
||||||
"tables2": tables2.view(dtype=torch.float16),
|
"tables2": tables2.view(dtype=torch.float16),
|
||||||
|
"tune_metadata": tune_metadata,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -508,7 +507,6 @@ class HiggsLinear(torch.nn.Module):
|
|||||||
self.num_bits = num_bits
|
self.num_bits = num_bits
|
||||||
self.group_size = group_size
|
self.group_size = group_size
|
||||||
self.hadamard_size = hadamard_size
|
self.hadamard_size = hadamard_size
|
||||||
self.num_sms_packed = nn.Parameter(torch.tensor(-1, dtype=torch.int32, device=device), requires_grad=False)
|
|
||||||
|
|
||||||
assert in_features % group_size == 0
|
assert in_features % group_size == 0
|
||||||
assert num_bits in [2, 3, 4]
|
assert num_bits in [2, 3, 4]
|
||||||
@@ -531,6 +529,7 @@ class HiggsLinear(torch.nn.Module):
|
|||||||
self.register_parameter("bias", None)
|
self.register_parameter("bias", None)
|
||||||
|
|
||||||
self.workspace = None # must be set externally to be reused among layers
|
self.workspace = None # must be set externally to be reused among layers
|
||||||
|
self.tune_metadata: TuneMetaData = None # must be set externally because architecture dependent
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = pad_to_block(x, [-1], self.hadamard_size)
|
x = pad_to_block(x, [-1], self.hadamard_size)
|
||||||
@@ -538,16 +537,15 @@ class HiggsLinear(torch.nn.Module):
|
|||||||
if self.workspace is None:
|
if self.workspace is None:
|
||||||
raise Exception("Workspace must be set before calling forward")
|
raise Exception("Workspace must be set before calling forward")
|
||||||
|
|
||||||
return flute.qgemm_hadamard(
|
return qgemm_v2(
|
||||||
x,
|
x,
|
||||||
self.weight,
|
self.weight,
|
||||||
self.scales,
|
self.scales,
|
||||||
self.tables,
|
self.tables,
|
||||||
self.tables2.view(dtype=torch.float32),
|
self.tables2.view(dtype=torch.float32),
|
||||||
self.workspace,
|
self.workspace,
|
||||||
self.num_bits,
|
self.tune_metadata,
|
||||||
self.group_size,
|
hadamard_size=self.hadamard_size,
|
||||||
self.hadamard_size,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -13,6 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from ..utils.logging import tqdm
|
||||||
from .base import HfQuantizer
|
from .base import HfQuantizer
|
||||||
from .quantizers_utils import get_module_from_name
|
from .quantizers_utils import get_module_from_name
|
||||||
|
|
||||||
@@ -30,20 +31,6 @@ if is_torch_available():
|
|||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def get_num_sms_from_device(device):
|
|
||||||
target_device_cc = torch.cuda.get_device_capability(device=device)
|
|
||||||
if target_device_cc == (8, 6):
|
|
||||||
return 84
|
|
||||||
elif target_device_cc == (8, 0):
|
|
||||||
return 108
|
|
||||||
elif target_device_cc == (8, 9):
|
|
||||||
return 128
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(
|
|
||||||
f"Device capability {target_device_cc} not supported for FLUTE (yet?) to verify your device capability check out https://developer.nvidia.com/cuda-gpus"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class HiggsHfQuantizer(HfQuantizer):
|
class HiggsHfQuantizer(HfQuantizer):
|
||||||
"""
|
"""
|
||||||
Quantizer of the HIGGS method. Enables the loading of prequantized models and in-flight quantization of full-precision models.
|
Quantizer of the HIGGS method. Enables the loading of prequantized models and in-flight quantization of full-precision models.
|
||||||
@@ -115,26 +102,24 @@ class HiggsHfQuantizer(HfQuantizer):
|
|||||||
self.quantization_config.group_size,
|
self.quantization_config.group_size,
|
||||||
self.quantization_config.hadamard_size,
|
self.quantization_config.hadamard_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
del param_value
|
del param_value
|
||||||
|
|
||||||
module, tensor_name = get_module_from_name(model, param_name)
|
module, _ = get_module_from_name(model, param_name)
|
||||||
|
module_name = ".".join(param_name.split(".")[:-1])
|
||||||
for key, value in flute_dict.items():
|
for key, value in flute_dict.items():
|
||||||
if key in module._parameters:
|
if key in module._parameters:
|
||||||
module._parameters[key] = torch.nn.Parameter(value, requires_grad=False)
|
module._parameters[key] = torch.nn.Parameter(value, requires_grad=False)
|
||||||
elif key in module._buffers:
|
elif key in module._buffers:
|
||||||
module._buffers[key] = torch.nn.Buffer(value)
|
module._buffers[key] = torch.nn.Buffer(value)
|
||||||
|
elif key == "tune_metadata":
|
||||||
|
module.tune_metadata = value
|
||||||
|
self.quantization_config.tune_metadata[module_name] = value.to_dict()
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unexpected key {key} in module {module}")
|
raise ValueError(f"Unexpected key {key} in module {module}")
|
||||||
|
|
||||||
if unexpected_keys is not None and param_name in unexpected_keys:
|
if unexpected_keys is not None and param_name in unexpected_keys:
|
||||||
unexpected_keys.remove(param_name)
|
unexpected_keys.remove(param_name)
|
||||||
|
|
||||||
module.num_sms_packed = torch.nn.Parameter(
|
|
||||||
torch.tensor(get_num_sms_from_device(target_device), device=target_device, dtype=torch.int32),
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _process_model_before_weight_loading(
|
def _process_model_before_weight_loading(
|
||||||
self,
|
self,
|
||||||
model: "PreTrainedModel",
|
model: "PreTrainedModel",
|
||||||
@@ -149,57 +134,42 @@ class HiggsHfQuantizer(HfQuantizer):
|
|||||||
model.config.quantization_config = self.quantization_config
|
model.config.quantization_config = self.quantization_config
|
||||||
|
|
||||||
def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
|
def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
|
||||||
import flute.utils
|
from flute.tune import TuneMetaData, maybe_tune_and_repack
|
||||||
|
from flute.utils import make_workspace_streamk
|
||||||
|
|
||||||
from ..integrations import HiggsLinear
|
from ..integrations import HiggsLinear
|
||||||
|
|
||||||
flute_workspaces = {}
|
flute_workspaces = {}
|
||||||
for name, module in model.named_modules():
|
flute_modules = {name: module for name, module in model.named_modules() if isinstance(module, HiggsLinear)}
|
||||||
if isinstance(module, HiggsLinear):
|
for name, module in tqdm(flute_modules.items(), desc="Repacking HIGGS modules", leave=False):
|
||||||
# Every HiggsLinear needs a "workspace": a buffer for the unpacking operation.
|
# Every HiggsLinear needs a "workspace": a buffer for the unpacking operation.
|
||||||
# This buffer needs to be on the same device as the weights, but can be reused across modules otherwise.
|
# This buffer needs to be on the same device as the weights, but can be reused across modules otherwise.
|
||||||
if module.weight.device not in flute_workspaces:
|
if module.weight.device not in flute_workspaces:
|
||||||
flute_workspaces[module.weight.device] = flute.utils.make_workspace_streamk(
|
flute_workspaces[module.weight.device] = make_workspace_streamk(device=module.weight.device)
|
||||||
device=module.weight.device
|
module.workspace = flute_workspaces[module.weight.device]
|
||||||
)
|
|
||||||
module.workspace = flute_workspaces[module.weight.device]
|
|
||||||
|
|
||||||
# FLUTE weights are packed in a way that is optimized for a specific number of SMs (GPU streaming multiprocessors).
|
# FLUTE weights are packed in a way that is optimized for a specific number of SMs (GPU streaming multiprocessors).
|
||||||
# If the model is loaded on a different device than the one it was saved on, we need to repack the weights.
|
# If the model is loaded on a different device than the one it was saved on, we need to repack the weights.
|
||||||
if module.num_sms_packed.item() != get_num_sms_from_device(module.weight.device):
|
module.tune_metadata = TuneMetaData.from_dict(self.quantization_config.tune_metadata[name])
|
||||||
new_device = module.weight.device
|
module.weight.data, module.tune_metadata = maybe_tune_and_repack(
|
||||||
new_num_sms = get_num_sms_from_device(new_device)
|
weight=module.weight.data,
|
||||||
module.weight.data = flute.utils.pack(
|
scales=module.scales.data,
|
||||||
flute.utils.unpack(
|
metadata=module.tune_metadata,
|
||||||
weight=module.weight.data,
|
)
|
||||||
scales=module.scales.data,
|
self.quantization_config.tune_metadata[name] = module.tune_metadata.to_dict()
|
||||||
workspace=module.workspace,
|
|
||||||
num_bits=module.num_bits,
|
|
||||||
group_size=module.group_size,
|
|
||||||
num_sms_packed=module.num_sms_packed.item(),
|
|
||||||
).T.contiguous(),
|
|
||||||
module.num_bits,
|
|
||||||
module.group_size,
|
|
||||||
)
|
|
||||||
module.num_sms_packed = torch.nn.Parameter(
|
|
||||||
torch.tensor(new_num_sms, device=new_device, dtype=torch.int32),
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
|
def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
|
||||||
from ..integrations import HiggsLinear
|
from ..integrations import HiggsLinear
|
||||||
|
|
||||||
not_missing_keys = []
|
higgs_names = {name for name, module in model.named_modules() if isinstance(module, HiggsLinear)}
|
||||||
for name, module in model.named_modules():
|
|
||||||
if isinstance(module, HiggsLinear):
|
def should_update(key: str) -> bool:
|
||||||
for missing in missing_keys:
|
if key.endswith(".weight") or key.endswith(".bias"):
|
||||||
if (
|
return False
|
||||||
(name in missing or name in f"{prefix}.{missing}")
|
full_key = f"{prefix}.{key}"
|
||||||
and not missing.endswith(".weight")
|
return any(name in key or name in full_key for name in higgs_names)
|
||||||
and not missing.endswith(".bias")
|
|
||||||
):
|
return [key for key in missing_keys if not should_update(key)]
|
||||||
not_missing_keys.append(missing)
|
|
||||||
return [k for k in missing_keys if k not in not_missing_keys]
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_trainable(self, model: Optional["PreTrainedModel"] = None):
|
def is_trainable(self, model: Optional["PreTrainedModel"] = None):
|
||||||
|
|||||||
@@ -639,7 +639,7 @@ def is_flax_available():
|
|||||||
|
|
||||||
def is_flute_available():
|
def is_flute_available():
|
||||||
try:
|
try:
|
||||||
return importlib.util.find_spec("flute") is not None and importlib.metadata.version("flute-kernel") >= "0.3.0"
|
return importlib.util.find_spec("flute") is not None and importlib.metadata.version("flute-kernel") >= "0.4.1"
|
||||||
except importlib.metadata.PackageNotFoundError:
|
except importlib.metadata.PackageNotFoundError:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|||||||
@@ -1404,6 +1404,8 @@ class HiggsConfig(QuantizationConfigMixin):
|
|||||||
Hadamard size for the HIGGS method. Default is 512. Input dimension of matrices is padded to this value. Decreasing this below 512 will reduce the quality of the quantization.
|
Hadamard size for the HIGGS method. Default is 512. Input dimension of matrices is padded to this value. Decreasing this below 512 will reduce the quality of the quantization.
|
||||||
group_size (int, *optional*, defaults to 256):
|
group_size (int, *optional*, defaults to 256):
|
||||||
Group size for the HIGGS method. Can be 64, 128 or 256. Decreasing it barely affects the performance. Default is 256. Must be a divisor of hadamard_size.
|
Group size for the HIGGS method. Can be 64, 128 or 256. Decreasing it barely affects the performance. Default is 256. Must be a divisor of hadamard_size.
|
||||||
|
tune_metadata ('dict', *optional*, defaults to {}):
|
||||||
|
Module-wise metadata (gemm block shapes, GPU metadata, etc.) for saving the kernel tuning results. Default is an empty dictionary. Is set automatically during tuning.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -1413,16 +1415,20 @@ class HiggsConfig(QuantizationConfigMixin):
|
|||||||
modules_to_not_convert: Optional[List[str]] = None,
|
modules_to_not_convert: Optional[List[str]] = None,
|
||||||
hadamard_size: int = 512,
|
hadamard_size: int = 512,
|
||||||
group_size: int = 256,
|
group_size: int = 256,
|
||||||
|
tune_metadata: Optional[Dict[str, Any]] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
if modules_to_not_convert is None:
|
if modules_to_not_convert is None:
|
||||||
modules_to_not_convert = ["lm_head"]
|
modules_to_not_convert = ["lm_head"]
|
||||||
|
if tune_metadata is None:
|
||||||
|
tune_metadata = {}
|
||||||
self.quant_method = QuantizationMethod.HIGGS
|
self.quant_method = QuantizationMethod.HIGGS
|
||||||
self.bits = bits
|
self.bits = bits
|
||||||
self.p = p
|
self.p = p
|
||||||
self.modules_to_not_convert = modules_to_not_convert
|
self.modules_to_not_convert = modules_to_not_convert
|
||||||
self.hadamard_size = hadamard_size
|
self.hadamard_size = hadamard_size
|
||||||
self.group_size = group_size
|
self.group_size = group_size
|
||||||
|
self.tune_metadata = tune_metadata
|
||||||
|
|
||||||
self.post_init()
|
self.post_init()
|
||||||
|
|
||||||
|
|||||||
@@ -65,12 +65,12 @@ class HiggsConfigTest(unittest.TestCase):
|
|||||||
@require_accelerate
|
@require_accelerate
|
||||||
# @require_read_token
|
# @require_read_token
|
||||||
class HiggsTest(unittest.TestCase):
|
class HiggsTest(unittest.TestCase):
|
||||||
model_name = "meta-llama/Meta-Llama-3.1-8B"
|
model_name = "unsloth/Llama-3.2-1B"
|
||||||
|
|
||||||
input_text = "A quick brown fox jumps over the"
|
input_text = "Font test: A quick brown fox jumps over the"
|
||||||
max_new_tokens = 2
|
max_new_tokens = 2
|
||||||
|
|
||||||
EXPECTED_OUTPUT = "A quick brown fox jumps over the lazy dog"
|
EXPECTED_OUTPUT = "Font test: A quick brown fox jumps over the lazy dog"
|
||||||
|
|
||||||
device_map = "cuda"
|
device_map = "cuda"
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user