From 4795219228ee9ad63f07dc6fa769f08a13b77cb9 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Wed, 7 Jun 2023 15:27:46 +0200 Subject: [PATCH] [`bnb`] Fix bnb skip modules (#24043) * fix skip modules test * oops * address comments --- src/transformers/utils/bitsandbytes.py | 10 +++++++--- tests/bitsandbytes/test_mixed_int8.py | 20 ++++++++++++++++++++ 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/src/transformers/utils/bitsandbytes.py b/src/transformers/utils/bitsandbytes.py index d9b4c037b4..978726be4b 100644 --- a/src/transformers/utils/bitsandbytes.py +++ b/src/transformers/utils/bitsandbytes.py @@ -109,16 +109,18 @@ def set_module_quantized_tensor_to_device(module, tensor_name, device, value=Non module._parameters[tensor_name] = new_value -def _replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None): +def _replace_with_bnb_linear( + model, modules_to_not_convert=None, current_key_name=None, quantization_config=None, has_been_replaced=False +): """ Private method that wraps the recursion for module replacement. Returns the converted model and a boolean that indicates if the conversion has been successfull or not. """ - has_been_replaced = False for name, module in model.named_children(): if current_key_name is None: current_key_name = [] + current_key_name.append(name) if isinstance(module, nn.Linear) and name not in modules_to_not_convert: # Check if the current key is not in the `modules_to_not_convert` @@ -151,14 +153,16 @@ def _replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_nam has_been_replaced = True # Force requires grad to False to avoid unexpected errors model._modules[name].requires_grad_(False) - # Remove the last key for recursion if len(list(module.children())) > 0: _, has_been_replaced = _replace_with_bnb_linear( module, modules_to_not_convert, current_key_name, quantization_config, + has_been_replaced=has_been_replaced, ) + # Remove the last key for recursion + current_key_name.pop(-1) return model, has_been_replaced diff --git a/tests/bitsandbytes/test_mixed_int8.py b/tests/bitsandbytes/test_mixed_int8.py index d178dbff1c..09157a251e 100644 --- a/tests/bitsandbytes/test_mixed_int8.py +++ b/tests/bitsandbytes/test_mixed_int8.py @@ -146,6 +146,26 @@ class MixedInt8Test(BaseMixedInt8Test): if name not in ["lm_head"] + T5PreTrainedModel._keep_in_fp32_modules: self.assertTrue(module.weight.dtype == torch.int8) + def test_llm_skip(self): + r""" + A simple test to check if `llm_int8_skip_modules` works as expected + """ + import bitsandbytes as bnb + + quantization_config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_skip_modules=["classifier"]) + seq_classification_model = AutoModelForSequenceClassification.from_pretrained( + "roberta-large-mnli", quantization_config=quantization_config + ) + self.assertTrue(seq_classification_model.roberta.encoder.layer[0].output.dense.weight.dtype == torch.int8) + self.assertTrue( + isinstance(seq_classification_model.roberta.encoder.layer[0].output.dense, bnb.nn.Linear8bitLt) + ) + + self.assertTrue(isinstance(seq_classification_model.classifier.dense, nn.Linear)) + self.assertTrue(seq_classification_model.classifier.dense.weight.dtype != torch.int8) + self.assertTrue(isinstance(seq_classification_model.classifier.out_proj, nn.Linear)) + self.assertTrue(seq_classification_model.classifier.out_proj != torch.int8) + def test_generate_quality(self): r""" Test the generation quality of the quantized model and see that we are matching the expected output.