Remove `.to()` restriction for 4-bit model (#33122)

* remove `.to()` restriction for 4-bit model

* Update src/transformers/modeling_utils.py

Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>

* bitsandbytes: prevent dtype casting while allowing device movement with .to or .cuda

* quality fix

* Improve warning message for .to() and .cuda() on bnb quantized models

---------

Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
This commit is contained in:
Marc Sun
2024-09-02 16:28:50 +02:00
committed by GitHub
parent 97c0f45b9c
commit 9ea1eacd11
2 changed files with 77 additions and 31 deletions

View File

@@ -256,29 +256,56 @@ class Bnb4BitTest(Base4bitTest):
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
def test_device_assignment(self):
    """
    Check that a 4-bit model can be moved to CPU and back to CUDA device 0 with `.to`,
    and that its reported memory footprint is the same after each move.

    Skipped when the installed bitsandbytes is older than 0.43.2.
    """
    installed_bnb = version.parse(importlib.metadata.version("bitsandbytes"))
    if installed_bnb < version.parse("0.43.2"):
        self.skipTest(reason="This test requires bitsandbytes >= 0.43.2")

    model = self.model_4bit
    # Reference footprint taken before any device movement.
    footprint_before = model.get_memory_footprint()

    # First hop: onto the CPU.
    model.to("cpu")
    self.assertEqual(model.device.type, "cpu")
    self.assertAlmostEqual(model.get_memory_footprint(), footprint_before)

    # Second hop: back onto CUDA device 0.
    model.to(0)
    self.assertEqual(model.device, torch.device(0))
    self.assertAlmostEqual(model.get_memory_footprint(), footprint_before)
def test_device_and_dtype_assignment(self):
r"""
Test whether trying to cast (or assigning a device to) a model after converting it in 8-bit will throw an error.
Test whether trying to cast (or assigning a device to) a model after converting it in 4-bit will throw an error.
Checks also if other models are casted correctly.
"""
with self.assertRaises(ValueError):
# Tries with `str`
self.model_4bit.to("cpu")
# Moving with `to` or `cuda` is not supported with versions < 0.43.2.
if version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.43.2"):
with self.assertRaises(ValueError):
# Tries with `str`
self.model_4bit.to("cpu")
with self.assertRaises(ValueError):
# Tries with a `device`
self.model_4bit.to(torch.device("cuda:0"))
with self.assertRaises(ValueError):
# Tries with `cuda`
self.model_4bit.cuda()
with self.assertRaises(ValueError):
# Tries with a `dtype``
# Tries with a `dtype`
self.model_4bit.to(torch.float16)
with self.assertRaises(ValueError):
# Tries with a `device`
self.model_4bit.to(torch.device("cuda:0"))
# Tries with a `dtype` and `device`
self.model_4bit.to(device="cuda:0", dtype=torch.float16)
with self.assertRaises(ValueError):
# Tries with a `device`
# Tries with a cast
self.model_4bit.float()
with self.assertRaises(ValueError):
# Tries with a `device`
# Tries with a cast
self.model_4bit.half()
# Test if we did not break anything
@@ -287,6 +314,9 @@ class Bnb4BitTest(Base4bitTest):
self.model_fp16 = self.model_fp16.to(torch.float32)
_ = self.model_fp16.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)
# Check that this does not throw an error
_ = self.model_fp16.cuda()
# Check this does not throw an error
_ = self.model_fp16.to("cpu")