From fd6a0ade9b89c415ea213ef1aa07c9b2c32a4d75 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Mon, 16 Oct 2023 19:56:53 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=A8=F0=9F=9A=A8=F0=9F=9A=A8=20[`Quanti?= =?UTF-8?q?zation`]=20Store=20the=20original=20dtype=20in=20the=20config?= =?UTF-8?q?=20as=20a=20private=20attribute=20=F0=9F=9A=A8=F0=9F=9A=A8?= =?UTF-8?q?=F0=9F=9A=A8=20(#26761)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * First step * fix * add adjustements for gptq * change to `_pre_quantization_dtype` * Update src/transformers/modeling_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * fix serialization * Apply suggestions from code review Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * fixup --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/configuration_utils.py | 6 +++++ src/transformers/modeling_utils.py | 27 +++++++++++++++++++++-- tests/quantization/bnb/test_4bit.py | 8 +++++++ tests/quantization/bnb/test_mixed_int8.py | 8 +++++++ tests/quantization/gptq/test_gptq.py | 20 +++++++++++++++++ 5 files changed, 67 insertions(+), 2 deletions(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index c718fc5323..a23f928ecb 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -854,6 +854,9 @@ class PretrainedConfig(PushToHubMixin): else self.quantization_config ) + # pop the `_pre_quantization_dtype` as torch.dtypes are not serializable. 
+ _ = serializable_config_dict.pop("_pre_quantization_dtype", None) + self.dict_torch_dtype_to_str(serializable_config_dict) if "_flash_attn_2_enabled" in serializable_config_dict: @@ -896,6 +899,9 @@ class PretrainedConfig(PushToHubMixin): else self.quantization_config ) + # pop the `_pre_quantization_dtype` as torch.dtypes are not serializable. + _ = output.pop("_pre_quantization_dtype", None) + self.dict_torch_dtype_to_str(output) return output diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 9afac652f9..d567b9438a 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2178,8 +2178,25 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix "`.to` is not supported for `4-bit` or `8-bit` bitsandbytes models. Please use the model as it is, since the" " model has already been set to the correct devices and casted to the correct `dtype`." ) - else: - return super().to(*args, **kwargs) + elif getattr(self, "quantization_method", None) == QuantizationMethod.GPTQ: + # For GPTQ models, we prevent users from casting the model to another dtype to restrict unwanted behaviours. + # the correct API should be to load the model with the desired dtype directly through `from_pretrained`. + dtype_present_in_args = False + + if "dtype" not in kwargs: + for arg in args: + if isinstance(arg, torch.dtype): + dtype_present_in_args = True + break + else: + dtype_present_in_args = True + + if dtype_present_in_args: + raise ValueError( + "You cannot cast a GPTQ model in a new `dtype`. Make sure to load the model using `from_pretrained` using the desired" + " `dtype` by passing the correct `torch_dtype` argument." 
+ ) + return super().to(*args, **kwargs) def half(self, *args): # Checks if the model is quantized @@ -3165,6 +3182,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix if hasattr(model, "quantization_method"): model.is_quantized = True + # We store the original dtype for quantized models as we cannot easily retrieve it + # once the weights have been quantized + # Note that once you have loaded a quantized model, you can't change its dtype so this will + # remain a single source of truth + config._pre_quantization_dtype = torch_dtype + if isinstance(device_map, str): special_dtypes = {} if load_in_8bit or load_in_4bit: diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py index 801173da79..a70e3d8832 100644 --- a/tests/quantization/bnb/test_4bit.py +++ b/tests/quantization/bnb/test_4bit.py @@ -156,6 +156,14 @@ class Bnb4BitTest(Base4bitTest): linear = get_some_linear_layer(self.model_4bit) self.assertTrue(linear.weight.__class__ == Params4bit) + def test_original_dtype(self): + r""" + A simple test to check if the model successfully stores the original dtype + """ + self.assertTrue(hasattr(self.model_4bit.config, "_pre_quantization_dtype")) + self.assertFalse(hasattr(self.model_fp16.config, "_pre_quantization_dtype")) + self.assertTrue(self.model_4bit.config._pre_quantization_dtype == torch.float16) + def test_linear_are_4bit(self): r""" A simple test to check if the model conversion has been done correctly by checking on the diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py index 670be57d0c..bbd1879fb1 100644 --- a/tests/quantization/bnb/test_mixed_int8.py +++ b/tests/quantization/bnb/test_mixed_int8.py @@ -186,6 +186,14 @@ class MixedInt8Test(BaseMixedInt8Test): _ = config.to_json_string() + def test_original_dtype(self): + r""" + A simple test to check if the model successfully stores the original dtype + """ + 
self.assertTrue(hasattr(self.model_8bit.config, "_pre_quantization_dtype")) + self.assertFalse(hasattr(self.model_fp16.config, "_pre_quantization_dtype")) + self.assertTrue(self.model_8bit.config._pre_quantization_dtype == torch.float16) + def test_memory_footprint(self): r""" A simple test to check if the model conversion has been done correctly by checking on the diff --git a/tests/quantization/gptq/test_gptq.py b/tests/quantization/gptq/test_gptq.py index 9139836571..4c7587f063 100644 --- a/tests/quantization/gptq/test_gptq.py +++ b/tests/quantization/gptq/test_gptq.py @@ -145,6 +145,26 @@ class GPTQTest(unittest.TestCase): self.assertAlmostEqual(self.mem_fp16 / mem_quantized, self.EXPECTED_RELATIVE_DIFFERENCE) + def test_device_and_dtype_assignment(self): + r""" + Test that moving a GPTQ-quantized model to a device still works. + Checks also that casting it to another `dtype` raises a `ValueError`. + """ + # This should work + _ = self.quantized_model.to(0) + + with self.assertRaises(ValueError): + # Tries with a `dtype` + self.quantized_model.to(torch.float16) + + def test_original_dtype(self): + r""" + A simple test to check if the model successfully stores the original dtype + """ + self.assertTrue(hasattr(self.quantized_model.config, "_pre_quantization_dtype")) + self.assertFalse(hasattr(self.model_fp16.config, "_pre_quantization_dtype")) + self.assertTrue(self.quantized_model.config._pre_quantization_dtype == torch.float16) + def test_quantized_layers_class(self): """ Simple test to check if the model conversion has been done correctly by checking on