[Awq] Enable the possibility to skip quantization for some target modules (#27950)

* v1

* add docstring

* add tests

* add awq 0.1.8

* oops

* fix test
This commit is contained in:
Younes Belkada
2023-12-25 11:06:56 +01:00
committed by GitHub
parent 29e7a1e183
commit fa21ead73d
4 changed files with 42 additions and 1 deletions

View File

@@ -88,6 +88,7 @@ class AwqConfigTest(unittest.TestCase):
class AwqTest(unittest.TestCase):
model_name = "TheBloke/Mistral-7B-v0.1-AWQ"
dummy_transformers_model_name = "bigscience/bloom-560m"
model_with_no_k_proj_quantized = "hf-internal-testing/opt-125m-awq-no-k-proj"
input_text = "Hello my name is"
@@ -223,6 +224,24 @@ class AwqTest(unittest.TestCase):
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
def test_quantized_model_no_k_proj_quantized(self):
"""
Simple test that checks if the quantized model is working properly with multiple GPUs
"""
dummy_input = torch.LongTensor([[0, 1, 0]]).to(torch_device)
quantized_model = AutoModelForCausalLM.from_pretrained(self.model_with_no_k_proj_quantized).to(torch_device)
self.assertTrue(isinstance(quantized_model.model.decoder.layers[0].self_attn.k_proj, torch.nn.Linear))
self.assertFalse(isinstance(quantized_model.model.decoder.layers[0].self_attn.v_proj, torch.nn.Linear))
EXPECTED_OUTPUT = torch.LongTensor([[0, 1, 0, 50118, 50118, 133, 248, 12, 134, 16, 10, 372, 2031]]).to(
torch_device
)
output = quantized_model.generate(dummy_input, max_new_tokens=10)
self.assertTrue((EXPECTED_OUTPUT == output).all())
@slow
@require_torch_gpu