[bnb] Fix bnb skip modules (#24043)

* fix skip modules test

* oops

* address comments
This commit is contained in:
Younes Belkada
2023-06-07 15:27:46 +02:00
committed by GitHub
parent a1160185ff
commit 4795219228
2 changed files with 27 additions and 3 deletions

View File

@@ -109,16 +109,18 @@ def set_module_quantized_tensor_to_device(module, tensor_name, device, value=Non
module._parameters[tensor_name] = new_value module._parameters[tensor_name] = new_value
def _replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None): def _replace_with_bnb_linear(
model, modules_to_not_convert=None, current_key_name=None, quantization_config=None, has_been_replaced=False
):
""" """
Private method that wraps the recursion for module replacement. Private method that wraps the recursion for module replacement.
Returns the converted model and a boolean that indicates if the conversion has been successfull or not. Returns the converted model and a boolean that indicates if the conversion has been successfull or not.
""" """
has_been_replaced = False
for name, module in model.named_children(): for name, module in model.named_children():
if current_key_name is None: if current_key_name is None:
current_key_name = [] current_key_name = []
current_key_name.append(name)
if isinstance(module, nn.Linear) and name not in modules_to_not_convert: if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
# Check if the current key is not in the `modules_to_not_convert` # Check if the current key is not in the `modules_to_not_convert`
@@ -151,14 +153,16 @@ def _replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_nam
has_been_replaced = True has_been_replaced = True
# Force requires grad to False to avoid unexpected errors # Force requires grad to False to avoid unexpected errors
model._modules[name].requires_grad_(False) model._modules[name].requires_grad_(False)
# Remove the last key for recursion
if len(list(module.children())) > 0: if len(list(module.children())) > 0:
_, has_been_replaced = _replace_with_bnb_linear( _, has_been_replaced = _replace_with_bnb_linear(
module, module,
modules_to_not_convert, modules_to_not_convert,
current_key_name, current_key_name,
quantization_config, quantization_config,
has_been_replaced=has_been_replaced,
) )
# Remove the last key for recursion
current_key_name.pop(-1)
return model, has_been_replaced return model, has_been_replaced

View File

@@ -146,6 +146,26 @@ class MixedInt8Test(BaseMixedInt8Test):
if name not in ["lm_head"] + T5PreTrainedModel._keep_in_fp32_modules: if name not in ["lm_head"] + T5PreTrainedModel._keep_in_fp32_modules:
self.assertTrue(module.weight.dtype == torch.int8) self.assertTrue(module.weight.dtype == torch.int8)
def test_llm_skip(self):
    r"""
    Check that `llm_int8_skip_modules` keeps the listed modules out of 8-bit quantization.

    Loads `roberta-large-mnli` in 8-bit with `llm_int8_skip_modules=["classifier"]` and verifies that a linear
    layer of the encoder is converted to `bnb.nn.Linear8bitLt` with int8 weights, while both linear layers nested
    under `classifier` (`dense` and `out_proj`) are still `nn.Linear` modules whose weights are not int8.
    """
    import bitsandbytes as bnb

    quantization_config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_skip_modules=["classifier"])
    seq_classification_model = AutoModelForSequenceClassification.from_pretrained(
        "roberta-large-mnli", quantization_config=quantization_config
    )

    # A linear layer that is not in the skip list must have been quantized.
    encoder_dense = seq_classification_model.roberta.encoder.layer[0].output.dense
    self.assertEqual(encoder_dense.weight.dtype, torch.int8)
    self.assertIsInstance(encoder_dense, bnb.nn.Linear8bitLt)

    # Every linear layer nested under the skipped `classifier` module must be left untouched.
    classifier = seq_classification_model.classifier
    self.assertIsInstance(classifier.dense, nn.Linear)
    self.assertNotEqual(classifier.dense.weight.dtype, torch.int8)
    self.assertIsInstance(classifier.out_proj, nn.Linear)
    # Compare the weight dtype, not the module itself: a module is never equal to a
    # `torch.dtype`, so asserting `out_proj != torch.int8` would pass unconditionally.
    self.assertNotEqual(classifier.out_proj.weight.dtype, torch.int8)
def test_generate_quality(self): def test_generate_quality(self):
r""" r"""
Test the generation quality of the quantized model and see that we are matching the expected output. Test the generation quality of the quantized model and see that we are matching the expected output.