[awq] replace scale when we have GELU (#30074)
* fix awq test * style * add log * new fix * style * only modifying impacted model in the end * rename function
This commit is contained in:
@@ -471,3 +471,22 @@ class AwqFusedTest(unittest.TestCase):
|
||||
|
||||
outputs = model.generate(**inputs, max_new_tokens=12)
|
||||
self.assertEqual(tokenizer.decode(outputs[0], skip_special_tokens=True), self.EXPECTED_GENERATION_MIXTRAL)
|
||||
|
||||
|
||||
@slow
@require_torch_gpu
@require_auto_awq
@require_accelerate
class AwqScaleTest(unittest.TestCase):
    """Checks the activation module of an AWQ-quantized checkpoint after loading."""

    # Hub id of the AWQ-quantized checkpoint loaded by the test below.
    model_name = "TechxGenus/starcoder2-3b-AWQ"

    def test_load_quantized_model(self):
        """
        Simple test that checks if the scales have been replaced in the quantized model
        """
        # Imported locally rather than at file top; presumably because `awq` is an
        # optional dependency gated by `@require_auto_awq` -- confirm against the
        # rest of the file.
        from awq.modules.act import ScaledActivation

        quantized_model = AutoModelForCausalLM.from_pretrained(
            self.model_name, torch_dtype=torch.float16, device_map="cuda"
        )

        # The MLP activation of the first layer is expected to be wrapped in
        # ScaledActivation once the quantized model has been loaded.
        self.assertIsInstance(quantized_model.model.layers[0].mlp.act, ScaledActivation)
|
||||
|
||||
Reference in New Issue
Block a user