[bnb] Fix bnb skip modules (#24043)
* fix skip modules test * oops * address comments
This commit is contained in:
@@ -146,6 +146,26 @@ class MixedInt8Test(BaseMixedInt8Test):
|
||||
if name not in ["lm_head"] + T5PreTrainedModel._keep_in_fp32_modules:
|
||||
self.assertTrue(module.weight.dtype == torch.int8)
|
||||
|
||||
def test_llm_skip(self):
|
||||
r"""
|
||||
A simple test to check if `llm_int8_skip_modules` works as expected
|
||||
"""
|
||||
import bitsandbytes as bnb
|
||||
|
||||
quantization_config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_skip_modules=["classifier"])
|
||||
seq_classification_model = AutoModelForSequenceClassification.from_pretrained(
|
||||
"roberta-large-mnli", quantization_config=quantization_config
|
||||
)
|
||||
self.assertTrue(seq_classification_model.roberta.encoder.layer[0].output.dense.weight.dtype == torch.int8)
|
||||
self.assertTrue(
|
||||
isinstance(seq_classification_model.roberta.encoder.layer[0].output.dense, bnb.nn.Linear8bitLt)
|
||||
)
|
||||
|
||||
self.assertTrue(isinstance(seq_classification_model.classifier.dense, nn.Linear))
|
||||
self.assertTrue(seq_classification_model.classifier.dense.weight.dtype != torch.int8)
|
||||
self.assertTrue(isinstance(seq_classification_model.classifier.out_proj, nn.Linear))
|
||||
self.assertTrue(seq_classification_model.classifier.out_proj != torch.int8)
|
||||
|
||||
def test_generate_quality(self):
|
||||
r"""
|
||||
Test the generation quality of the quantized model and see that we are matching the expected output.
|
||||
|
||||
Reference in New Issue
Block a user