🚨🚨🚨 [Quantization] Store the original dtype in the config as a private attribute 🚨🚨🚨 (#26761)
* First step * fix * add adjustements for gptq * change to `_pre_quantization_dtype` * Update src/transformers/modeling_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * fix serialization * Apply suggestions from code review Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * fixup --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -854,6 +854,9 @@ class PretrainedConfig(PushToHubMixin):
|
|||||||
else self.quantization_config
|
else self.quantization_config
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
|
||||||
|
_ = serializable_config_dict.pop("_pre_quantization_dtype", None)
|
||||||
|
|
||||||
self.dict_torch_dtype_to_str(serializable_config_dict)
|
self.dict_torch_dtype_to_str(serializable_config_dict)
|
||||||
|
|
||||||
if "_flash_attn_2_enabled" in serializable_config_dict:
|
if "_flash_attn_2_enabled" in serializable_config_dict:
|
||||||
@@ -896,6 +899,9 @@ class PretrainedConfig(PushToHubMixin):
|
|||||||
else self.quantization_config
|
else self.quantization_config
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
|
||||||
|
_ = output.pop("_pre_quantization_dtype", None)
|
||||||
|
|
||||||
self.dict_torch_dtype_to_str(output)
|
self.dict_torch_dtype_to_str(output)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|||||||
@@ -2178,7 +2178,24 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
"`.to` is not supported for `4-bit` or `8-bit` bitsandbytes models. Please use the model as it is, since the"
|
"`.to` is not supported for `4-bit` or `8-bit` bitsandbytes models. Please use the model as it is, since the"
|
||||||
" model has already been set to the correct devices and casted to the correct `dtype`."
|
" model has already been set to the correct devices and casted to the correct `dtype`."
|
||||||
)
|
)
|
||||||
|
elif getattr(self, "quantization_method", None) == QuantizationMethod.GPTQ:
|
||||||
|
# For GPTQ models, we prevent users from casting the model to another dytpe to restrict unwanted behaviours.
|
||||||
|
# the correct API should be to load the model with the desired dtype directly through `from_pretrained`.
|
||||||
|
dtype_present_in_args = False
|
||||||
|
|
||||||
|
if "dtype" not in kwargs:
|
||||||
|
for arg in args:
|
||||||
|
if isinstance(arg, torch.dtype):
|
||||||
|
dtype_present_in_args = True
|
||||||
|
break
|
||||||
else:
|
else:
|
||||||
|
dtype_present_in_args = True
|
||||||
|
|
||||||
|
if dtype_present_in_args:
|
||||||
|
raise ValueError(
|
||||||
|
"You cannot cast a GPTQ model in a new `dtype`. Make sure to load the model using `from_pretrained` using the desired"
|
||||||
|
" `dtype` by passing the correct `torch_dtype` argument."
|
||||||
|
)
|
||||||
return super().to(*args, **kwargs)
|
return super().to(*args, **kwargs)
|
||||||
|
|
||||||
def half(self, *args):
|
def half(self, *args):
|
||||||
@@ -3165,6 +3182,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
if hasattr(model, "quantization_method"):
|
if hasattr(model, "quantization_method"):
|
||||||
model.is_quantized = True
|
model.is_quantized = True
|
||||||
|
|
||||||
|
# We store the original dtype for quantized models as we cannot easily retrieve it
|
||||||
|
# once the weights have been quantized
|
||||||
|
# Note that once you have loaded a quantized model, you can't change its dtype so this will
|
||||||
|
# remain a single source of truth
|
||||||
|
config._pre_quantization_dtype = torch_dtype
|
||||||
|
|
||||||
if isinstance(device_map, str):
|
if isinstance(device_map, str):
|
||||||
special_dtypes = {}
|
special_dtypes = {}
|
||||||
if load_in_8bit or load_in_4bit:
|
if load_in_8bit or load_in_4bit:
|
||||||
|
|||||||
@@ -156,6 +156,14 @@ class Bnb4BitTest(Base4bitTest):
|
|||||||
linear = get_some_linear_layer(self.model_4bit)
|
linear = get_some_linear_layer(self.model_4bit)
|
||||||
self.assertTrue(linear.weight.__class__ == Params4bit)
|
self.assertTrue(linear.weight.__class__ == Params4bit)
|
||||||
|
|
||||||
|
def test_original_dtype(self):
|
||||||
|
r"""
|
||||||
|
A simple test to check if the model succesfully stores the original dtype
|
||||||
|
"""
|
||||||
|
self.assertTrue(hasattr(self.model_4bit.config, "_pre_quantization_dtype"))
|
||||||
|
self.assertFalse(hasattr(self.model_fp16.config, "_pre_quantization_dtype"))
|
||||||
|
self.assertTrue(self.model_4bit.config._pre_quantization_dtype == torch.float16)
|
||||||
|
|
||||||
def test_linear_are_4bit(self):
|
def test_linear_are_4bit(self):
|
||||||
r"""
|
r"""
|
||||||
A simple test to check if the model conversion has been done correctly by checking on the
|
A simple test to check if the model conversion has been done correctly by checking on the
|
||||||
|
|||||||
@@ -186,6 +186,14 @@ class MixedInt8Test(BaseMixedInt8Test):
|
|||||||
|
|
||||||
_ = config.to_json_string()
|
_ = config.to_json_string()
|
||||||
|
|
||||||
|
def test_original_dtype(self):
|
||||||
|
r"""
|
||||||
|
A simple test to check if the model succesfully stores the original dtype
|
||||||
|
"""
|
||||||
|
self.assertTrue(hasattr(self.model_8bit.config, "_pre_quantization_dtype"))
|
||||||
|
self.assertFalse(hasattr(self.model_fp16.config, "_pre_quantization_dtype"))
|
||||||
|
self.assertTrue(self.model_8bit.config._pre_quantization_dtype == torch.float16)
|
||||||
|
|
||||||
def test_memory_footprint(self):
|
def test_memory_footprint(self):
|
||||||
r"""
|
r"""
|
||||||
A simple test to check if the model conversion has been done correctly by checking on the
|
A simple test to check if the model conversion has been done correctly by checking on the
|
||||||
|
|||||||
@@ -145,6 +145,26 @@ class GPTQTest(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertAlmostEqual(self.mem_fp16 / mem_quantized, self.EXPECTED_RELATIVE_DIFFERENCE)
|
self.assertAlmostEqual(self.mem_fp16 / mem_quantized, self.EXPECTED_RELATIVE_DIFFERENCE)
|
||||||
|
|
||||||
|
def test_device_and_dtype_assignment(self):
|
||||||
|
r"""
|
||||||
|
Test whether trying to cast (or assigning a device to) a model after converting it in 8-bit will throw an error.
|
||||||
|
Checks also if other models are casted correctly.
|
||||||
|
"""
|
||||||
|
# This should work
|
||||||
|
_ = self.quantized_model.to(0)
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
# Tries with a `dtype``
|
||||||
|
self.quantized_model.to(torch.float16)
|
||||||
|
|
||||||
|
def test_original_dtype(self):
|
||||||
|
r"""
|
||||||
|
A simple test to check if the model succesfully stores the original dtype
|
||||||
|
"""
|
||||||
|
self.assertTrue(hasattr(self.quantized_model.config, "_pre_quantization_dtype"))
|
||||||
|
self.assertFalse(hasattr(self.model_fp16.config, "_pre_quantization_dtype"))
|
||||||
|
self.assertTrue(self.quantized_model.config._pre_quantization_dtype == torch.float16)
|
||||||
|
|
||||||
def test_quantized_layers_class(self):
|
def test_quantized_layers_class(self):
|
||||||
"""
|
"""
|
||||||
Simple test to check if the model conversion has been done correctly by checking on
|
Simple test to check if the model conversion has been done correctly by checking on
|
||||||
|
|||||||
Reference in New Issue
Block a user