FIX / Quantization: Add extra validation for bnb config (#31135)
add validation for bnb config
This commit is contained in:
@@ -383,6 +383,10 @@ class BitsAndBytesConfig(QuantizationConfigMixin):
|
|||||||
if bnb_4bit_quant_storage is None:
|
if bnb_4bit_quant_storage is None:
|
||||||
self.bnb_4bit_quant_storage = torch.uint8
|
self.bnb_4bit_quant_storage = torch.uint8
|
||||||
elif isinstance(bnb_4bit_quant_storage, str):
|
elif isinstance(bnb_4bit_quant_storage, str):
|
||||||
|
if bnb_4bit_quant_storage not in ["float16", "float32", "int8", "uint8", "float64", "bfloat16"]:
|
||||||
|
raise ValueError(
|
||||||
|
"`bnb_4bit_quant_storage` must be a valid string (one of 'float16', 'float32', 'int8', 'uint8', 'float64', 'bfloat16') "
|
||||||
|
)
|
||||||
self.bnb_4bit_quant_storage = getattr(torch, bnb_4bit_quant_storage)
|
self.bnb_4bit_quant_storage = getattr(torch, bnb_4bit_quant_storage)
|
||||||
elif isinstance(bnb_4bit_quant_storage, torch.dtype):
|
elif isinstance(bnb_4bit_quant_storage, torch.dtype):
|
||||||
self.bnb_4bit_quant_storage = bnb_4bit_quant_storage
|
self.bnb_4bit_quant_storage = bnb_4bit_quant_storage
|
||||||
|
|||||||
@@ -303,6 +303,13 @@ class Bnb4BitTest(Base4bitTest):
|
|||||||
model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-small", load_in_4bit=True, device_map="auto")
|
model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-small", load_in_4bit=True, device_map="auto")
|
||||||
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32)
|
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32)
|
||||||
|
|
||||||
|
def test_bnb_4bit_wrong_config(self):
|
||||||
|
r"""
|
||||||
|
Test whether creating a bnb config with unsupported values leads to errors.
|
||||||
|
"""
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
_ = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_storage="add")
|
||||||
|
|
||||||
|
|
||||||
@require_bitsandbytes
|
@require_bitsandbytes
|
||||||
@require_accelerate
|
@require_accelerate
|
||||||
|
|||||||
Reference in New Issue
Block a user