[Mixtral / Awq] Add mixtral fused modules for Awq (#28240)
* add mixtral fused modules * add changes from modeling utils * add test * fix test + rope theta issue * Update src/transformers/modeling_utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * add tests --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -30,6 +30,13 @@ AWQ_FUSED_MAPPINGS = {
|
|||||||
"layernorm": ["input_layernorm", "post_attention_layernorm", "norm"],
|
"layernorm": ["input_layernorm", "post_attention_layernorm", "norm"],
|
||||||
"use_alibi": False,
|
"use_alibi": False,
|
||||||
},
|
},
|
||||||
|
"mixtral": {
|
||||||
|
"attention": ["q_proj", "k_proj", "v_proj", "o_proj"],
|
||||||
|
"mlp": ["w1", "w3", "w2"],
|
||||||
|
"layernorm": ["input_layernorm", "post_attention_layernorm", "norm"],
|
||||||
|
"use_alibi": False,
|
||||||
|
"rope_theta": 1000000.0,
|
||||||
|
},
|
||||||
"llama": {
|
"llama": {
|
||||||
"attention": ["q_proj", "k_proj", "v_proj", "o_proj"],
|
"attention": ["q_proj", "k_proj", "v_proj", "o_proj"],
|
||||||
"mlp": ["gate_proj", "up_proj", "down_proj"],
|
"mlp": ["gate_proj", "up_proj", "down_proj"],
|
||||||
@@ -353,6 +360,8 @@ def _fuse_awq_attention_layers(model, module, modules_to_fuse, current_module_na
|
|||||||
previous_device,
|
previous_device,
|
||||||
modules_to_fuse["max_seq_len"],
|
modules_to_fuse["max_seq_len"],
|
||||||
use_alibi=modules_to_fuse["use_alibi"],
|
use_alibi=modules_to_fuse["use_alibi"],
|
||||||
|
# The default value in autoawq is set to 10000.0
|
||||||
|
rope_theta=modules_to_fuse.get("rope_theta", 10000.0),
|
||||||
)
|
)
|
||||||
|
|
||||||
fused_attention_layer.is_hf_transformers = True
|
fused_attention_layer.is_hf_transformers = True
|
||||||
|
|||||||
@@ -3587,7 +3587,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
# a `modules_to_not_convert` attribute we need to manually set that attribute into the
|
# a `modules_to_not_convert` attribute we need to manually set that attribute into the
|
||||||
# passed `quantization_config`
|
# passed `quantization_config`
|
||||||
elif (
|
elif (
|
||||||
quantization_config.modules_to_not_convert is None
|
getattr(quantization_config, "modules_to_not_convert", None) is None
|
||||||
and "modules_to_not_convert" in config.quantization_config
|
and "modules_to_not_convert" in config.quantization_config
|
||||||
):
|
):
|
||||||
quantization_config.modules_to_not_convert = config.quantization_config["modules_to_not_convert"]
|
quantization_config.modules_to_not_convert = config.quantization_config["modules_to_not_convert"]
|
||||||
|
|||||||
@@ -254,6 +254,9 @@ class AwqFusedTest(unittest.TestCase):
|
|||||||
custom_mapping_model_id = "TheBloke/Yi-34B-AWQ"
|
custom_mapping_model_id = "TheBloke/Yi-34B-AWQ"
|
||||||
custom_model_revision = "f1b2cd1b7459ceecfdc1fac5bb8725f13707c589"
|
custom_model_revision = "f1b2cd1b7459ceecfdc1fac5bb8725f13707c589"
|
||||||
|
|
||||||
|
mixtral_model_name = "casperhansen/mixtral-instruct-awq"
|
||||||
|
mixtral_model_revision = "87dd4ec502dde74fb3a624835c776b000d190c3b"
|
||||||
|
|
||||||
multi_modal_model_name = "ybelkada/llava-1.5-7b-hf-awq"
|
multi_modal_model_name = "ybelkada/llava-1.5-7b-hf-awq"
|
||||||
multi_modal_model_code_revision = "ad108a50f5b9e681bdd7378409f57b7fa59a7442"
|
multi_modal_model_code_revision = "ad108a50f5b9e681bdd7378409f57b7fa59a7442"
|
||||||
|
|
||||||
@@ -265,6 +268,7 @@ class AwqFusedTest(unittest.TestCase):
|
|||||||
|
|
||||||
EXPECTED_GENERATION = prompt + "\n\nThis is a classic puzzle that has been around for"
|
EXPECTED_GENERATION = prompt + "\n\nThis is a classic puzzle that has been around for"
|
||||||
EXPECTED_GENERATION_CUSTOM_MODEL = "HelloWorld.java:11)\r\n\tat org"
|
EXPECTED_GENERATION_CUSTOM_MODEL = "HelloWorld.java:11)\r\n\tat org"
|
||||||
|
EXPECTED_GENERATION_MIXTRAL = prompt + " You're on the North Pole.\n\nThe"
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
gc.collect()
|
gc.collect()
|
||||||
@@ -300,6 +304,24 @@ class AwqFusedTest(unittest.TestCase):
|
|||||||
with self.assertRaises(ValueError), tempfile.TemporaryDirectory() as tmpdirname:
|
with self.assertRaises(ValueError), tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
model.save_pretrained(tmpdirname)
|
model.save_pretrained(tmpdirname)
|
||||||
|
|
||||||
|
def test_fused_modules_to_not_convert(self):
|
||||||
|
"""
|
||||||
|
Test if fused + modules to_not_covnert work as expected
|
||||||
|
"""
|
||||||
|
model_id = "hf-internal-testing/Mixtral-tiny-AWQ"
|
||||||
|
|
||||||
|
quantization_config = AwqConfig(bits=4, fuse_max_seq_len=128, do_fuse=True)
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
model_id,
|
||||||
|
quantization_config=quantization_config,
|
||||||
|
low_cpu_mem_usage=True,
|
||||||
|
).to(torch_device)
|
||||||
|
|
||||||
|
# Check if model has been correctly fused
|
||||||
|
self._check_fused_modules(model)
|
||||||
|
# Checks if the modules_to_not_convert (here gate layer) is a Linear
|
||||||
|
self.assertTrue(isinstance(model.model.layers[0].block_sparse_moe.gate, torch.nn.Linear))
|
||||||
|
|
||||||
def test_generation_fused(self):
|
def test_generation_fused(self):
|
||||||
"""
|
"""
|
||||||
Test generation quality for fused models - single batch case
|
Test generation quality for fused models - single batch case
|
||||||
@@ -408,3 +430,24 @@ class AwqFusedTest(unittest.TestCase):
|
|||||||
|
|
||||||
outputs = model.generate(**inputs, max_new_tokens=12)
|
outputs = model.generate(**inputs, max_new_tokens=12)
|
||||||
self.assertEqual(tokenizer.decode(outputs[0], skip_special_tokens=True), self.EXPECTED_GENERATION_CUSTOM_MODEL)
|
self.assertEqual(tokenizer.decode(outputs[0], skip_special_tokens=True), self.EXPECTED_GENERATION_CUSTOM_MODEL)
|
||||||
|
|
||||||
|
@require_torch_multi_gpu
|
||||||
|
def test_generation_mixtral_fused(self):
|
||||||
|
"""
|
||||||
|
Text generation test for Mixtral + AWQ + fused
|
||||||
|
"""
|
||||||
|
quantization_config = AwqConfig(bits=4, fuse_max_seq_len=1024, do_fuse=True)
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
self.mixtral_model_name,
|
||||||
|
quantization_config=quantization_config,
|
||||||
|
device_map="auto",
|
||||||
|
revision=self.mixtral_model_revision,
|
||||||
|
)
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(self.mixtral_model_name)
|
||||||
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
|
||||||
|
inputs = tokenizer([self.prompt, self.prompt], return_tensors="pt", padding=True).to(torch_device)
|
||||||
|
|
||||||
|
outputs = model.generate(**inputs, max_new_tokens=12)
|
||||||
|
self.assertEqual(tokenizer.decode(outputs[0], skip_special_tokens=True), self.EXPECTED_GENERATION_MIXTRAL)
|
||||||
|
|||||||
Reference in New Issue
Block a user