remove to restriction for 4-bit model (#33122)
* remove to restiction for 4-bit model * Update src/transformers/modeling_utils.py Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> * bitsandbytes: prevent dtype casting while allowing device movement with .to or .cuda * quality fix * Improve warning message for .to() and .cuda() on bnb quantized models --------- Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
This commit is contained in:
@@ -2861,38 +2861,54 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
def cuda(self, *args, **kwargs):
|
||||
if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ:
|
||||
raise ValueError("`.cuda` is not supported for HQQ-quantized models.")
|
||||
# Checks if the model has been loaded in 8-bit
|
||||
# Checks if the model has been loaded in 4-bit or 8-bit with BNB
|
||||
if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
|
||||
raise ValueError(
|
||||
"Calling `cuda()` is not supported for `4-bit` or `8-bit` quantized models. Please use the model as it is, since the"
|
||||
" model has already been set to the correct devices and casted to the correct `dtype`."
|
||||
)
|
||||
if getattr(self, "is_loaded_in_8bit", False):
|
||||
raise ValueError(
|
||||
"Calling `cuda()` is not supported for `8-bit` quantized models. "
|
||||
" Please use the model as it is, since the model has already been set to the correct devices."
|
||||
)
|
||||
elif version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.43.2"):
|
||||
raise ValueError(
|
||||
"Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
|
||||
f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
|
||||
)
|
||||
else:
|
||||
return super().cuda(*args, **kwargs)
|
||||
|
||||
@wraps(torch.nn.Module.to)
|
||||
def to(self, *args, **kwargs):
|
||||
# For BNB/GPTQ models, we prevent users from casting the model to another dytpe to restrict unwanted behaviours.
|
||||
# the correct API should be to load the model with the desired dtype directly through `from_pretrained`.
|
||||
dtype_present_in_args = "dtype" in kwargs
|
||||
|
||||
if not dtype_present_in_args:
|
||||
for arg in args:
|
||||
if isinstance(arg, torch.dtype):
|
||||
dtype_present_in_args = True
|
||||
break
|
||||
|
||||
if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ:
|
||||
raise ValueError("`.to` is not supported for HQQ-quantized models.")
|
||||
# Checks if the model has been loaded in 8-bit
|
||||
# Checks if the model has been loaded in 4-bit or 8-bit with BNB
|
||||
if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
|
||||
raise ValueError(
|
||||
"`.to` is not supported for `4-bit` or `8-bit` bitsandbytes models. Please use the model as it is, since the"
|
||||
" model has already been set to the correct devices and casted to the correct `dtype`."
|
||||
)
|
||||
if dtype_present_in_args:
|
||||
raise ValueError(
|
||||
"You cannot cast a bitsandbytes model in a new `dtype`. Make sure to load the model using `from_pretrained` using the"
|
||||
" desired `dtype` by passing the correct `torch_dtype` argument."
|
||||
)
|
||||
|
||||
if getattr(self, "is_loaded_in_8bit", False):
|
||||
raise ValueError(
|
||||
"`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the"
|
||||
" model has already been set to the correct devices and casted to the correct `dtype`."
|
||||
)
|
||||
elif version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.43.2"):
|
||||
raise ValueError(
|
||||
"Calling `to()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
|
||||
f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
|
||||
)
|
||||
elif getattr(self, "quantization_method", None) == QuantizationMethod.GPTQ:
|
||||
# For GPTQ models, we prevent users from casting the model to another dytpe to restrict unwanted behaviours.
|
||||
# the correct API should be to load the model with the desired dtype directly through `from_pretrained`.
|
||||
dtype_present_in_args = False
|
||||
|
||||
if "dtype" not in kwargs:
|
||||
for arg in args:
|
||||
if isinstance(arg, torch.dtype):
|
||||
dtype_present_in_args = True
|
||||
break
|
||||
else:
|
||||
dtype_present_in_args = True
|
||||
|
||||
if dtype_present_in_args:
|
||||
raise ValueError(
|
||||
"You cannot cast a GPTQ model in a new `dtype`. Make sure to load the model using `from_pretrained` using the desired"
|
||||
|
||||
@@ -256,29 +256,56 @@ class Bnb4BitTest(Base4bitTest):
|
||||
|
||||
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
|
||||
|
||||
def test_device_assignment(self):
|
||||
if version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.43.2"):
|
||||
self.skipTest(reason="This test requires bitsandbytes >= 0.43.2")
|
||||
|
||||
mem_before = self.model_4bit.get_memory_footprint()
|
||||
|
||||
# Move to CPU
|
||||
self.model_4bit.to("cpu")
|
||||
self.assertEqual(self.model_4bit.device.type, "cpu")
|
||||
self.assertAlmostEqual(self.model_4bit.get_memory_footprint(), mem_before)
|
||||
|
||||
# Move back to CUDA device
|
||||
self.model_4bit.to(0)
|
||||
self.assertEqual(self.model_4bit.device, torch.device(0))
|
||||
self.assertAlmostEqual(self.model_4bit.get_memory_footprint(), mem_before)
|
||||
|
||||
def test_device_and_dtype_assignment(self):
|
||||
r"""
|
||||
Test whether trying to cast (or assigning a device to) a model after converting it in 8-bit will throw an error.
|
||||
Test whether trying to cast (or assigning a device to) a model after converting it in 4-bit will throw an error.
|
||||
Checks also if other models are casted correctly.
|
||||
"""
|
||||
with self.assertRaises(ValueError):
|
||||
# Tries with `str`
|
||||
self.model_4bit.to("cpu")
|
||||
|
||||
# Moving with `to` or `cuda` is not supported with versions < 0.43.2.
|
||||
if version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.43.2"):
|
||||
with self.assertRaises(ValueError):
|
||||
# Tries with `str`
|
||||
self.model_4bit.to("cpu")
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
# Tries with a `device`
|
||||
self.model_4bit.to(torch.device("cuda:0"))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
# Tries with `cuda`
|
||||
self.model_4bit.cuda()
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
# Tries with a `dtype``
|
||||
# Tries with a `dtype`
|
||||
self.model_4bit.to(torch.float16)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
# Tries with a `device`
|
||||
self.model_4bit.to(torch.device("cuda:0"))
|
||||
# Tries with a `dtype` and `device`
|
||||
self.model_4bit.to(device="cuda:0", dtype=torch.float16)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
# Tries with a `device`
|
||||
# Tries with a cast
|
||||
self.model_4bit.float()
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
# Tries with a `device`
|
||||
# Tries with a cast
|
||||
self.model_4bit.half()
|
||||
|
||||
# Test if we did not break anything
|
||||
@@ -287,6 +314,9 @@ class Bnb4BitTest(Base4bitTest):
|
||||
self.model_fp16 = self.model_fp16.to(torch.float32)
|
||||
_ = self.model_fp16.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)
|
||||
|
||||
# Check that this does not throw an error
|
||||
_ = self.model_fp16.cuda()
|
||||
|
||||
# Check this does not throw an error
|
||||
_ = self.model_fp16.to("cpu")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user