Faster generation using AWQ + Fused modules (#27411)
* v1 fusing modules * add fused mlp support * up * fix CI * block save_pretrained * fixup * small fix * add new condition * add v1 docs * add some comments * style * fix nit * adapt from suggestion * add check * change arg names * change variables name * Update src/transformers/integrations/awq.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * style * split up into 3 different private methods * more conditions * more checks * add fused tests for custom models * fix * fix tests * final update docs * final fixes * fix importlib metadata * Update src/transformers/utils/quantization_config.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * change it to `do_fuse` * nit * Update src/transformers/utils/quantization_config.py Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * Update src/transformers/utils/quantization_config.py Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * Update src/transformers/utils/quantization_config.py Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * few fixes * revert * fix test * fix copies * raise error if model is not quantized * add test * use quantization_config.config when fusing * Update src/transformers/modeling_utils.py --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
@@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
@@ -107,6 +108,11 @@ class AwqTest(unittest.TestCase):
|
||||
device_map=cls.device_map,
|
||||
)
|
||||
|
||||
def tearDown(self):
    """
    Run after every test: garbage-collect, empty the CUDA cache, then garbage-collect again
    """
    # The collect / empty_cache / collect order is kept exactly as the test-suite uses it.
    gc.collect()
    torch.cuda.empty_cache()
    gc.collect()
|
||||
|
||||
def test_quantized_model_conversion(self):
|
||||
"""
|
||||
Simple test that checks if the quantized model has been converted properly
|
||||
@@ -158,6 +164,13 @@ class AwqTest(unittest.TestCase):
|
||||
output = self.quantized_model.generate(**input_ids, max_new_tokens=40)
|
||||
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
|
||||
|
||||
def test_raise_if_non_quantized(self):
    """
    Simple test that checks that loading `facebook/opt-125m` together with an `AwqConfig` raises a `ValueError`
    """
    awq_config = AwqConfig(bits=4)
    checkpoint = "facebook/opt-125m"

    # The returned model is irrelevant: the call itself is expected to raise.
    with self.assertRaises(ValueError):
        AutoModelForCausalLM.from_pretrained(checkpoint, quantization_config=awq_config)
|
||||
|
||||
def test_quantized_model_bf16(self):
|
||||
"""
|
||||
Simple test that checks if the quantized model is working properly with bf16
|
||||
@@ -195,22 +208,6 @@ class AwqTest(unittest.TestCase):
|
||||
output = model.generate(**input_ids, max_new_tokens=40)
|
||||
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
|
||||
|
||||
def test_raise_quantization(self):
    """
    Checks that passing an `AwqConfig` when loading `dummy_transformers_model_name` raises a `ValueError`
    whose message matches the expected text exactly
    """
    # Compared verbatim against the raised error, trailing space included.
    expected_message = "You cannot pass an `AwqConfig` when loading a model as you can only use AWQ models for inference. To quantize transformers models with AWQ algorithm, please refer to our quantization docs: https://huggingface.co/docs/transformers/main_classes/quantization "
    awq_config = AwqConfig(bits=4)

    with self.assertRaises(ValueError) as raised:
        AutoModelForCausalLM.from_pretrained(self.dummy_transformers_model_name, quantization_config=awq_config)

    self.assertEqual(str(raised.exception), expected_message)
|
||||
|
||||
@require_torch_multi_gpu
|
||||
def test_quantized_model_multi_gpu(self):
|
||||
"""
|
||||
@@ -225,3 +222,144 @@ class AwqTest(unittest.TestCase):
|
||||
output = quantized_model.generate(**input_ids, max_new_tokens=40)
|
||||
|
||||
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
|
||||
|
||||
|
||||
@slow
@require_torch_gpu
@require_auto_awq
@require_accelerate
class AwqFusedTest(unittest.TestCase):
    """
    Tests for AWQ checkpoints loaded with fused modules: the fused layers must be present,
    `save_pretrained` must be refused, and generation must match the reference strings below.
    """

    # Checkpoint (pinned to a revision) used by the default fusing tests.
    model_name = "TheBloke/Mistral-7B-OpenOrca-AWQ"
    model_revision = "7048b2af77d0dd1c81b000b19d73f9cc8950b510"

    # Checkpoint (pinned to a revision) used by the test that passes an explicit `modules_to_fuse` mapping.
    custom_mapping_model_id = "TheBloke/Yi-34B-AWQ"
    custom_model_revision = "f1b2cd1b7459ceecfdc1fac5bb8725f13707c589"

    prompt = (
        "You're standing on the surface of the Earth. "
        "You walk one mile south, one mile west and one mile north. "
        "You end up exactly where you started. Where are you?"
    )

    EXPECTED_GENERATION = prompt + "\n\nThis is a classic puzzle that has been around for"
    EXPECTED_GENERATION_CUSTOM_MODEL = "HelloWorld.java:11)\r\n\tat org"

    def tearDown(self):
        """
        Run after every test: garbage-collect, empty the CUDA cache, then garbage-collect again
        """
        gc.collect()
        torch.cuda.empty_cache()
        gc.collect()

    def _check_fused_modules(self, model):
        """
        Assert that at least one sub-module of `model` has a class name listed among the fused module names
        """
        fused_class_names = ["QuantAttentionFused", "QuantFusedMLP", "FasterTransformerRMSNorm"]

        # `any` stops at the first match, like the explicit flag-and-break loop it replaces.
        found_fused = any(type(submodule).__name__ in fused_class_names for _, submodule in model.named_modules())

        self.assertTrue(found_fused, "Modules fusing not performed correctly!")

    def _load_fused_model(self):
        """
        Load `model_name` at `model_revision` with `AwqConfig(bits=4, fuse_max_seq_len=128, do_fuse=True)`,
        move it to `torch_device`, check that fused modules are present and return the model
        """
        fused_config = AwqConfig(bits=4, fuse_max_seq_len=128, do_fuse=True)

        fused_model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            quantization_config=fused_config,
            low_cpu_mem_usage=True,
            revision=self.model_revision,
        )
        fused_model = fused_model.to(torch_device)

        self._check_fused_modules(fused_model)
        return fused_model

    def test_raise_save_pretrained(self):
        """
        Test that calling `save_pretrained` on a fused model raises a `ValueError`
        """
        model = self._load_fused_model()

        with self.assertRaises(ValueError):
            with tempfile.TemporaryDirectory() as save_dir:
                model.save_pretrained(save_dir)

    def test_generation_fused(self):
        """
        Test that a fused model reproduces `EXPECTED_GENERATION` for a single prompt
        """
        model = self._load_fused_model()

        tokenizer = AutoTokenizer.from_pretrained(self.model_name, revision=self.model_revision)
        encoded = tokenizer(self.prompt, return_tensors="pt").to(torch_device)

        generated = model.generate(**encoded, max_new_tokens=12)
        decoded = tokenizer.decode(generated[0], skip_special_tokens=True)

        self.assertEqual(decoded, self.EXPECTED_GENERATION)

    def test_generation_fused_batched(self):
        """
        Test that a fused model reproduces `EXPECTED_GENERATION` when the prompt is batched twice
        """
        model = self._load_fused_model()

        tokenizer = AutoTokenizer.from_pretrained(self.model_name, revision=self.model_revision)
        # Padding is requested below, so the EOS token id is assigned as the pad token id first.
        tokenizer.pad_token_id = tokenizer.eos_token_id

        batch = [self.prompt, self.prompt]
        encoded = tokenizer(batch, return_tensors="pt", padding=True).to(torch_device)

        generated = model.generate(**encoded, max_new_tokens=12)
        # Only the first sequence of the batch is compared against the reference.
        decoded = tokenizer.decode(generated[0], skip_special_tokens=True)

        self.assertEqual(decoded, self.EXPECTED_GENERATION)

    @require_torch_multi_gpu
    def test_generation_custom_model(self):
        """
        Test that a model fused through an explicit `modules_to_fuse` mapping reproduces
        `EXPECTED_GENERATION_CUSTOM_MODEL`
        """
        custom_fusing_map = {
            "attention": ["q_proj", "k_proj", "v_proj", "o_proj"],
            "layernorm": ["ln1", "ln2", "norm"],
            "mlp": ["gate_proj", "up_proj", "down_proj"],
            "use_alibi": False,
            "num_attention_heads": 56,
            "num_key_value_heads": 8,
            "hidden_size": 7168,
        }
        fused_config = AwqConfig(
            bits=4,
            fuse_max_seq_len=512,
            modules_to_fuse=custom_fusing_map,
        )

        model = AutoModelForCausalLM.from_pretrained(
            self.custom_mapping_model_id,
            quantization_config=fused_config,
            trust_remote_code=True,
            device_map="balanced",
            revision=self.custom_model_revision,
        )

        self._check_fused_modules(model)

        tokenizer = AutoTokenizer.from_pretrained(
            self.custom_mapping_model_id, revision=self.custom_model_revision, trust_remote_code=True
        )

        prompt = "Hello"
        encoded = tokenizer(prompt, return_tensors="pt").to(torch_device)

        generated = model.generate(**encoded, max_new_tokens=12)
        decoded = tokenizer.decode(generated[0], skip_special_tokens=True)

        self.assertEqual(decoded, self.EXPECTED_GENERATION_CUSTOM_MODEL)
|
||||
|
||||
Reference in New Issue
Block a user