🚨🚨🚨 [Quantization] Store the original dtype in the config as a private attribute 🚨🚨🚨 (#26761)
* First step * fix * add adjustements for gptq * change to `_pre_quantization_dtype` * Update src/transformers/modeling_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * fix serialization * Apply suggestions from code review Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * fixup --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -156,6 +156,14 @@ class Bnb4BitTest(Base4bitTest):
|
||||
linear = get_some_linear_layer(self.model_4bit)
|
||||
self.assertTrue(linear.weight.__class__ == Params4bit)
|
||||
|
||||
def test_original_dtype(self):
|
||||
r"""
|
||||
A simple test to check if the model succesfully stores the original dtype
|
||||
"""
|
||||
self.assertTrue(hasattr(self.model_4bit.config, "_pre_quantization_dtype"))
|
||||
self.assertFalse(hasattr(self.model_fp16.config, "_pre_quantization_dtype"))
|
||||
self.assertTrue(self.model_4bit.config._pre_quantization_dtype == torch.float16)
|
||||
|
||||
def test_linear_are_4bit(self):
|
||||
r"""
|
||||
A simple test to check if the model conversion has been done correctly by checking on the
|
||||
|
||||
@@ -186,6 +186,14 @@ class MixedInt8Test(BaseMixedInt8Test):
|
||||
|
||||
_ = config.to_json_string()
|
||||
|
||||
def test_original_dtype(self):
|
||||
r"""
|
||||
A simple test to check if the model succesfully stores the original dtype
|
||||
"""
|
||||
self.assertTrue(hasattr(self.model_8bit.config, "_pre_quantization_dtype"))
|
||||
self.assertFalse(hasattr(self.model_fp16.config, "_pre_quantization_dtype"))
|
||||
self.assertTrue(self.model_8bit.config._pre_quantization_dtype == torch.float16)
|
||||
|
||||
def test_memory_footprint(self):
|
||||
r"""
|
||||
A simple test to check if the model conversion has been done correctly by checking on the
|
||||
|
||||
@@ -145,6 +145,26 @@ class GPTQTest(unittest.TestCase):
|
||||
|
||||
self.assertAlmostEqual(self.mem_fp16 / mem_quantized, self.EXPECTED_RELATIVE_DIFFERENCE)
|
||||
|
||||
def test_device_and_dtype_assignment(self):
|
||||
r"""
|
||||
Test whether trying to cast (or assigning a device to) a model after converting it in 8-bit will throw an error.
|
||||
Checks also if other models are casted correctly.
|
||||
"""
|
||||
# This should work
|
||||
_ = self.quantized_model.to(0)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
# Tries with a `dtype``
|
||||
self.quantized_model.to(torch.float16)
|
||||
|
||||
def test_original_dtype(self):
|
||||
r"""
|
||||
A simple test to check if the model succesfully stores the original dtype
|
||||
"""
|
||||
self.assertTrue(hasattr(self.quantized_model.config, "_pre_quantization_dtype"))
|
||||
self.assertFalse(hasattr(self.model_fp16.config, "_pre_quantization_dtype"))
|
||||
self.assertTrue(self.quantized_model.config._pre_quantization_dtype == torch.float16)
|
||||
|
||||
def test_quantized_layers_class(self):
|
||||
"""
|
||||
Simple test to check if the model conversion has been done correctly by checking on
|
||||
|
||||
Reference in New Issue
Block a user