Faster generation using AWQ + Fused modules (#27411)

* v1 fusing modules

* add fused mlp support

* up

* fix CI

* block save_pretrained

* fixup

* small fix

* add new condition

* add v1 docs

* add some comments

* style

* fix nit

* adapt from suggestion

* add check

* change arg names

* change variables name

* Update src/transformers/integrations/awq.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* style

* split up into 3 different private methods

* more conditions

* more checks

* add fused tests for custom models

* fix

* fix tests

* final update docs

* final fixes

* fix importlib metadata

* Update src/transformers/utils/quantization_config.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* change it to `do_fuse`

* nit

* Update src/transformers/utils/quantization_config.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update src/transformers/utils/quantization_config.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update src/transformers/utils/quantization_config.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* few fixes

* revert

* fix test

* fix copies

* raise error if model is not quantized

* add test

* use quantization_config.config when fusing

* Update src/transformers/modeling_utils.py

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
Younes Belkada
2023-12-05 12:14:45 +01:00
committed by GitHub
parent df40edfb00
commit fdb85be40f
7 changed files with 623 additions and 33 deletions

View File

@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import tempfile
import unittest
@@ -107,6 +108,11 @@ class AwqTest(unittest.TestCase):
device_map=cls.device_map,
)
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
gc.collect()
def test_quantized_model_conversion(self):
"""
Simple test that checks if the quantized model has been converted properly
@@ -158,6 +164,13 @@ class AwqTest(unittest.TestCase):
output = self.quantized_model.generate(**input_ids, max_new_tokens=40)
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
def test_raise_if_non_quantized(self):
model_id = "facebook/opt-125m"
quantization_config = AwqConfig(bits=4)
with self.assertRaises(ValueError):
_ = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config)
def test_quantized_model_bf16(self):
"""
Simple test that checks if the quantized model is working properly with bf16
@@ -195,22 +208,6 @@ class AwqTest(unittest.TestCase):
output = model.generate(**input_ids, max_new_tokens=40)
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
def test_raise_quantization(self):
"""
Simple test that checks if one passes a quantization config to quantize a model, it raises an error
"""
quantization_config = AwqConfig(bits=4)
with self.assertRaises(ValueError) as context:
_ = AutoModelForCausalLM.from_pretrained(
self.dummy_transformers_model_name, quantization_config=quantization_config
)
self.assertEqual(
str(context.exception),
"You cannot pass an `AwqConfig` when loading a model as you can only use AWQ models for inference. To quantize transformers models with AWQ algorithm, please refer to our quantization docs: https://huggingface.co/docs/transformers/main_classes/quantization ",
)
@require_torch_multi_gpu
def test_quantized_model_multi_gpu(self):
"""
@@ -225,3 +222,144 @@ class AwqTest(unittest.TestCase):
output = quantized_model.generate(**input_ids, max_new_tokens=40)
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
@slow
@require_torch_gpu
@require_auto_awq
@require_accelerate
class AwqFusedTest(unittest.TestCase):
model_name = "TheBloke/Mistral-7B-OpenOrca-AWQ"
model_revision = "7048b2af77d0dd1c81b000b19d73f9cc8950b510"
custom_mapping_model_id = "TheBloke/Yi-34B-AWQ"
custom_model_revision = "f1b2cd1b7459ceecfdc1fac5bb8725f13707c589"
prompt = (
"You're standing on the surface of the Earth. "
"You walk one mile south, one mile west and one mile north. "
"You end up exactly where you started. Where are you?"
)
EXPECTED_GENERATION = prompt + "\n\nThis is a classic puzzle that has been around for"
EXPECTED_GENERATION_CUSTOM_MODEL = "HelloWorld.java:11)\r\n\tat org"
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
gc.collect()
def _check_fused_modules(self, model):
has_fused_modules = False
fused_modules_name = ["QuantAttentionFused", "QuantFusedMLP", "FasterTransformerRMSNorm"]
for _, module in model.named_modules():
if module.__class__.__name__ in fused_modules_name:
has_fused_modules = True
break
self.assertTrue(has_fused_modules, "Modules fusing not performed correctly!")
def test_raise_save_pretrained(self):
"""
Test that `save_pretrained` is effectively blocked for fused models
"""
quantization_config = AwqConfig(bits=4, fuse_max_seq_len=128, do_fuse=True)
model = AutoModelForCausalLM.from_pretrained(
self.model_name,
quantization_config=quantization_config,
low_cpu_mem_usage=True,
revision=self.model_revision,
).to(torch_device)
self._check_fused_modules(model)
with self.assertRaises(ValueError), tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
def test_generation_fused(self):
"""
Test generation quality for fused models - single batch case
"""
quantization_config = AwqConfig(bits=4, fuse_max_seq_len=128, do_fuse=True)
model = AutoModelForCausalLM.from_pretrained(
self.model_name,
quantization_config=quantization_config,
low_cpu_mem_usage=True,
revision=self.model_revision,
).to(torch_device)
self._check_fused_modules(model)
tokenizer = AutoTokenizer.from_pretrained(self.model_name, revision=self.model_revision)
inputs = tokenizer(self.prompt, return_tensors="pt").to(torch_device)
outputs = model.generate(**inputs, max_new_tokens=12)
self.assertEqual(tokenizer.decode(outputs[0], skip_special_tokens=True), self.EXPECTED_GENERATION)
def test_generation_fused_batched(self):
"""
Test generation quality for fused models - multi batch case
"""
quantization_config = AwqConfig(bits=4, fuse_max_seq_len=128, do_fuse=True)
model = AutoModelForCausalLM.from_pretrained(
self.model_name,
quantization_config=quantization_config,
low_cpu_mem_usage=True,
revision=self.model_revision,
).to(torch_device)
self._check_fused_modules(model)
tokenizer = AutoTokenizer.from_pretrained(self.model_name, revision=self.model_revision)
tokenizer.pad_token_id = tokenizer.eos_token_id
inputs = tokenizer([self.prompt, self.prompt], return_tensors="pt", padding=True).to(torch_device)
outputs = model.generate(**inputs, max_new_tokens=12)
self.assertEqual(tokenizer.decode(outputs[0], skip_special_tokens=True), self.EXPECTED_GENERATION)
@require_torch_multi_gpu
def test_generation_custom_model(self):
"""
Test generation quality for fused models using custom fused map.
"""
quantization_config = AwqConfig(
bits=4,
fuse_max_seq_len=512,
modules_to_fuse={
"attention": ["q_proj", "k_proj", "v_proj", "o_proj"],
"layernorm": ["ln1", "ln2", "norm"],
"mlp": ["gate_proj", "up_proj", "down_proj"],
"use_alibi": False,
"num_attention_heads": 56,
"num_key_value_heads": 8,
"hidden_size": 7168,
},
)
model = AutoModelForCausalLM.from_pretrained(
self.custom_mapping_model_id,
quantization_config=quantization_config,
trust_remote_code=True,
device_map="balanced",
revision=self.custom_model_revision,
)
self._check_fused_modules(model)
tokenizer = AutoTokenizer.from_pretrained(
self.custom_mapping_model_id, revision=self.custom_model_revision, trust_remote_code=True
)
prompt = "Hello"
inputs = tokenizer(prompt, return_tensors="pt").to(torch_device)
outputs = model.generate(**inputs, max_new_tokens=12)
self.assertEqual(tokenizer.decode(outputs[0], skip_special_tokens=True), self.EXPECTED_GENERATION_CUSTOM_MODEL)