remove to restriction for 4-bit model (#33122)
* remove to restiction for 4-bit model * Update src/transformers/modeling_utils.py Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> * bitsandbytes: prevent dtype casting while allowing device movement with .to or .cuda * quality fix * Improve warning message for .to() and .cuda() on bnb quantized models --------- Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
This commit is contained in:
@@ -256,29 +256,56 @@ class Bnb4BitTest(Base4bitTest):
|
||||
|
||||
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
|
||||
|
||||
def test_device_assignment(self):
|
||||
if version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.43.2"):
|
||||
self.skipTest(reason="This test requires bitsandbytes >= 0.43.2")
|
||||
|
||||
mem_before = self.model_4bit.get_memory_footprint()
|
||||
|
||||
# Move to CPU
|
||||
self.model_4bit.to("cpu")
|
||||
self.assertEqual(self.model_4bit.device.type, "cpu")
|
||||
self.assertAlmostEqual(self.model_4bit.get_memory_footprint(), mem_before)
|
||||
|
||||
# Move back to CUDA device
|
||||
self.model_4bit.to(0)
|
||||
self.assertEqual(self.model_4bit.device, torch.device(0))
|
||||
self.assertAlmostEqual(self.model_4bit.get_memory_footprint(), mem_before)
|
||||
|
||||
def test_device_and_dtype_assignment(self):
|
||||
r"""
|
||||
Test whether trying to cast (or assigning a device to) a model after converting it in 8-bit will throw an error.
|
||||
Test whether trying to cast (or assigning a device to) a model after converting it in 4-bit will throw an error.
|
||||
Checks also if other models are casted correctly.
|
||||
"""
|
||||
with self.assertRaises(ValueError):
|
||||
# Tries with `str`
|
||||
self.model_4bit.to("cpu")
|
||||
|
||||
# Moving with `to` or `cuda` is not supported with versions < 0.43.2.
|
||||
if version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.43.2"):
|
||||
with self.assertRaises(ValueError):
|
||||
# Tries with `str`
|
||||
self.model_4bit.to("cpu")
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
# Tries with a `device`
|
||||
self.model_4bit.to(torch.device("cuda:0"))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
# Tries with `cuda`
|
||||
self.model_4bit.cuda()
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
# Tries with a `dtype``
|
||||
# Tries with a `dtype`
|
||||
self.model_4bit.to(torch.float16)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
# Tries with a `device`
|
||||
self.model_4bit.to(torch.device("cuda:0"))
|
||||
# Tries with a `dtype` and `device`
|
||||
self.model_4bit.to(device="cuda:0", dtype=torch.float16)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
# Tries with a `device`
|
||||
# Tries with a cast
|
||||
self.model_4bit.float()
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
# Tries with a `device`
|
||||
# Tries with a cast
|
||||
self.model_4bit.half()
|
||||
|
||||
# Test if we did not break anything
|
||||
@@ -287,6 +314,9 @@ class Bnb4BitTest(Base4bitTest):
|
||||
self.model_fp16 = self.model_fp16.to(torch.float32)
|
||||
_ = self.model_fp16.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)
|
||||
|
||||
# Check that this does not throw an error
|
||||
_ = self.model_fp16.cuda()
|
||||
|
||||
# Check this does not throw an error
|
||||
_ = self.model_fp16.to("cpu")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user