[bnb] Fix bnb skip modules (#24043)

* fix skip modules test

* oops

* address comments
This commit is contained in:
Younes Belkada
2023-06-07 15:27:46 +02:00
committed by GitHub
parent a1160185ff
commit 4795219228
2 changed files with 27 additions and 3 deletions

View File

@@ -146,6 +146,26 @@ class MixedInt8Test(BaseMixedInt8Test):
if name not in ["lm_head"] + T5PreTrainedModel._keep_in_fp32_modules:
self.assertTrue(module.weight.dtype == torch.int8)
def test_llm_skip(self):
r"""
A simple test to check if `llm_int8_skip_modules` works as expected
"""
import bitsandbytes as bnb
quantization_config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_skip_modules=["classifier"])
seq_classification_model = AutoModelForSequenceClassification.from_pretrained(
"roberta-large-mnli", quantization_config=quantization_config
)
self.assertTrue(seq_classification_model.roberta.encoder.layer[0].output.dense.weight.dtype == torch.int8)
self.assertTrue(
isinstance(seq_classification_model.roberta.encoder.layer[0].output.dense, bnb.nn.Linear8bitLt)
)
self.assertTrue(isinstance(seq_classification_model.classifier.dense, nn.Linear))
self.assertTrue(seq_classification_model.classifier.dense.weight.dtype != torch.int8)
self.assertTrue(isinstance(seq_classification_model.classifier.out_proj, nn.Linear))
self.assertTrue(seq_classification_model.classifier.out_proj != torch.int8)
def test_generate_quality(self):
r"""
Test the generation quality of the quantized model and see that we are matching the expected output.