Fix HQQ model param device transfer issue (#38466)
* Fix HQQ model param device transfer issue * modify a comment * clear the code and add test for hqq device/dtype * fix test hqq code quality of imports --------- Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
@@ -3897,7 +3897,20 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
@wraps(torch.nn.Module.cuda)
|
||||
def cuda(self, *args, **kwargs):
|
||||
if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ:
|
||||
raise ValueError("`.cuda` is not supported for HQQ-quantized models.")
|
||||
from hqq.core.quantize import HQQLinear
|
||||
|
||||
# Since HQQLinear stores some tensors in the 'meta' attribute,
|
||||
# it's necessary to manually call the `cuda` method on HQQLinear layers.
|
||||
super().cuda(*args, **kwargs)
|
||||
for module in self.modules():
|
||||
if isinstance(module, HQQLinear):
|
||||
if len(args) > 0:
|
||||
device = args[0]
|
||||
else:
|
||||
device = kwargs.get("device", "cuda")
|
||||
module.cuda(device)
|
||||
return self
|
||||
|
||||
# Checks if the model has been loaded in 4-bit or 8-bit with BNB
|
||||
if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
|
||||
if getattr(self, "is_loaded_in_8bit", False):
|
||||
@@ -3910,8 +3923,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
"Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
|
||||
f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
|
||||
)
|
||||
else:
|
||||
return super().cuda(*args, **kwargs)
|
||||
return super().cuda(*args, **kwargs)
|
||||
|
||||
@wraps(torch.nn.Module.to)
|
||||
def to(self, *args, **kwargs):
|
||||
@@ -3926,7 +3938,30 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
break
|
||||
|
||||
if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ:
|
||||
raise ValueError("`.to` is not supported for HQQ-quantized models.")
|
||||
from hqq.core.quantize import HQQLinear
|
||||
|
||||
# Since HQQLinear stores some tensors in the 'meta' attribute, we must
|
||||
# explicitly move the parameters to the target device for each HQQLinear layer after `to`.
|
||||
super().to(*args, **kwargs)
|
||||
for module in self.modules():
|
||||
if isinstance(module, HQQLinear):
|
||||
if "device" in kwargs:
|
||||
device = kwargs["device"]
|
||||
else:
|
||||
device = args[0]
|
||||
if "dtype" in kwargs:
|
||||
dtype = kwargs["dtype"]
|
||||
elif dtype_present_in_args:
|
||||
dtype = arg
|
||||
else:
|
||||
dtype = None
|
||||
# Due to the current messy implementation of HQQLinear, updating `compute_dtype`
|
||||
# followed by calling the `cuda` method achieves the intended behavior of `to`,
|
||||
# even when the target device is CPU.
|
||||
if dtype is not None:
|
||||
module.compute_dtype = dtype
|
||||
module.cuda(device)
|
||||
return self
|
||||
|
||||
if dtype_present_in_args and getattr(self, "quantization_method", None) == QuantizationMethod.QUARK:
|
||||
raise ValueError("Casting a Quark quantized model to a new `dtype` is not supported.")
|
||||
|
||||
Reference in New Issue
Block a user