From 9ea1eacd11b10acedd489c9a44c4ae45a358508d Mon Sep 17 00:00:00 2001 From: Marc Sun <57196510+SunMarc@users.noreply.github.com> Date: Mon, 2 Sep 2024 16:28:50 +0200 Subject: [PATCH] remove to restriction for 4-bit model (#33122) * remove to restriction for 4-bit model * Update src/transformers/modeling_utils.py Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> * bitsandbytes: prevent dtype casting while allowing device movement with .to or .cuda * quality fix * Improve warning message for .to() and .cuda() on bnb quantized models --------- Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> --- src/transformers/modeling_utils.py | 60 ++++++++++++++++++----------- tests/quantization/bnb/test_4bit.py | 48 ++++++++++++++++++----- 2 files changed, 77 insertions(+), 31 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index b943b5e798..f931a6af3e 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2861,38 +2861,54 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix def cuda(self, *args, **kwargs): if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ: raise ValueError("`.cuda` is not supported for HQQ-quantized models.") - # Checks if the model has been loaded in 8-bit + # Checks if the model has been loaded in 4-bit or 8-bit with BNB if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: - raise ValueError( - "Calling `cuda()` is not supported for `4-bit` or `8-bit` quantized models. Please use the model as it is, since the" - " model has already been set to the correct devices and casted to the correct `dtype`." - ) + if getattr(self, "is_loaded_in_8bit", False): + raise ValueError( + "Calling `cuda()` is not supported for `8-bit` quantized models. 
" + " Please use the model as it is, since the model has already been set to the correct devices." + ) + elif version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.43.2"): + raise ValueError( + "Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. " + f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2." + ) else: return super().cuda(*args, **kwargs) @wraps(torch.nn.Module.to) def to(self, *args, **kwargs): + # For BNB/GPTQ models, we prevent users from casting the model to another dtype to restrict unwanted behaviours. + # the correct API should be to load the model with the desired dtype directly through `from_pretrained`. + dtype_present_in_args = "dtype" in kwargs + + if not dtype_present_in_args: + for arg in args: + if isinstance(arg, torch.dtype): + dtype_present_in_args = True + break + if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ: raise ValueError("`.to` is not supported for HQQ-quantized models.") - # Checks if the model has been loaded in 8-bit + # Checks if the model has been loaded in 4-bit or 8-bit with BNB if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: - raise ValueError( - "`.to` is not supported for `4-bit` or `8-bit` bitsandbytes models. Please use the model as it is, since the" - " model has already been set to the correct devices and casted to the correct `dtype`." - ) + if dtype_present_in_args: + raise ValueError( + "You cannot cast a bitsandbytes model in a new `dtype`. Make sure to load the model using `from_pretrained` using the" + " desired `dtype` by passing the correct `torch_dtype` argument." + ) + + if getattr(self, "is_loaded_in_8bit", False): + raise ValueError( + "`.to` is not supported for `8-bit` bitsandbytes models. 
Please use the model as it is, since the" + " model has already been set to the correct devices and casted to the correct `dtype`." + ) + elif version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.43.2"): + raise ValueError( + "Calling `to()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. " + f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2." + ) elif getattr(self, "quantization_method", None) == QuantizationMethod.GPTQ: - # For GPTQ models, we prevent users from casting the model to another dytpe to restrict unwanted behaviours. - # the correct API should be to load the model with the desired dtype directly through `from_pretrained`. - dtype_present_in_args = False - - if "dtype" not in kwargs: - for arg in args: - if isinstance(arg, torch.dtype): - dtype_present_in_args = True - break - else: - dtype_present_in_args = True - if dtype_present_in_args: raise ValueError( "You cannot cast a GPTQ model in a new `dtype`. 
Make sure to load the model using `from_pretrained` using the desired" diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py index 71a2d7c815..785402b3f7 100644 --- a/tests/quantization/bnb/test_4bit.py +++ b/tests/quantization/bnb/test_4bit.py @@ -256,29 +256,56 @@ class Bnb4BitTest(Base4bitTest): self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS) + def test_device_assignment(self): + if version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.43.2"): + self.skipTest(reason="This test requires bitsandbytes >= 0.43.2") + + mem_before = self.model_4bit.get_memory_footprint() + + # Move to CPU + self.model_4bit.to("cpu") + self.assertEqual(self.model_4bit.device.type, "cpu") + self.assertAlmostEqual(self.model_4bit.get_memory_footprint(), mem_before) + + # Move back to CUDA device + self.model_4bit.to(0) + self.assertEqual(self.model_4bit.device, torch.device(0)) + self.assertAlmostEqual(self.model_4bit.get_memory_footprint(), mem_before) + def test_device_and_dtype_assignment(self): r""" - Test whether trying to cast (or assigning a device to) a model after converting it in 8-bit will throw an error. + Test whether trying to cast (or assigning a device to) a model after converting it in 4-bit will throw an error. Checks also if other models are casted correctly. """ - with self.assertRaises(ValueError): - # Tries with `str` - self.model_4bit.to("cpu") + + # Moving with `to` or `cuda` is not supported with versions < 0.43.2. 
+ if version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.43.2"): + with self.assertRaises(ValueError): + # Tries with `str` + self.model_4bit.to("cpu") + + with self.assertRaises(ValueError): + # Tries with a `device` + self.model_4bit.to(torch.device("cuda:0")) + + with self.assertRaises(ValueError): + # Tries with `cuda` + self.model_4bit.cuda() with self.assertRaises(ValueError): - # Tries with a `dtype`` + # Tries with a `dtype` self.model_4bit.to(torch.float16) with self.assertRaises(ValueError): - # Tries with a `device` - self.model_4bit.to(torch.device("cuda:0")) + # Tries with a `dtype` and `device` + self.model_4bit.to(device="cuda:0", dtype=torch.float16) with self.assertRaises(ValueError): - # Tries with a `device` + # Tries with a cast self.model_4bit.float() with self.assertRaises(ValueError): - # Tries with a `device` + # Tries with a cast self.model_4bit.half() # Test if we did not break anything @@ -287,6 +314,9 @@ class Bnb4BitTest(Base4bitTest): self.model_fp16 = self.model_fp16.to(torch.float32) _ = self.model_fp16.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10) + # Check that this does not throw an error + _ = self.model_fp16.cuda() + # Check this does not throw an error _ = self.model_fp16.to("cpu")