[BNB] Throw ValueError when trying to cast or assign (#20409)

* `bnb` ValueError when tries to cast or assign

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* remove docstrings

* change error log

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Younes Belkada
2022-11-23 15:51:50 +01:00
committed by GitHub
parent 03ae1f060b
commit ad654e4484
2 changed files with 70 additions and 0 deletions

View File

@@ -115,6 +115,46 @@ class MixedInt8Test(BaseMixedInt8Test):
with self.assertWarns(UserWarning), tempfile.TemporaryDirectory() as tmpdirname:
self.model_8bit.save_pretrained(tmpdirname)
def test_device_and_dtype_assignment(self):
r"""
Test whether trying to cast (or assigning a device to) a model after converting it in 8-bit will throw an error.
Checks also if other models are casted correctly.
"""
with self.assertRaises(ValueError):
# Tries with `str`
self.model_8bit.to("cpu")
with self.assertRaises(ValueError):
# Tries with a `dtype``
self.model_8bit.to(torch.float16)
with self.assertRaises(ValueError):
# Tries with a `device`
self.model_8bit.to(torch.device("cuda:0"))
with self.assertRaises(ValueError):
# Tries with a `device`
self.model_8bit.float()
with self.assertRaises(ValueError):
# Tries with a `device`
self.model_8bit.half()
# Test if we did not break anything
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
self.model_fp16 = self.model_fp16.to(torch.float32)
_ = self.model_fp16.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)
# Check this does not throw an error
_ = self.model_fp16.to("cpu")
# Check this does not throw an error
_ = self.model_fp16.half()
# Check this does not throw an error
_ = self.model_fp16.float()
class MixedInt8ModelClassesTest(BaseMixedInt8Test):
def setUp(self):