From 07bdbebb48a9fe1e748348e4e14ae0b4659e54c4 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Fri, 12 Jan 2024 06:55:54 +0100 Subject: [PATCH] [`Awq`] Add llava fused modules support (#28239) * add llava + fused modules * Update src/transformers/models/llava/modeling_llava.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/integrations/awq.py | 35 ++++++++++++++++--- src/transformers/modeling_utils.py | 8 +++++ .../models/llava/modeling_llava.py | 9 ++++- .../models/vipllava/modeling_vipllava.py | 9 ++++- tests/quantization/autoawq/test_awq.py | 26 ++++++++++++++ 5 files changed, 80 insertions(+), 7 deletions(-) diff --git a/src/transformers/integrations/awq.py b/src/transformers/integrations/awq.py index 336a216e40..dea74b2f7c 100644 --- a/src/transformers/integrations/awq.py +++ b/src/transformers/integrations/awq.py @@ -36,6 +36,12 @@ AWQ_FUSED_MAPPINGS = { "layernorm": ["input_layernorm", "post_attention_layernorm", "norm"], "use_alibi": False, }, + "llava": { + "attention": ["q_proj", "k_proj", "v_proj", "o_proj"], + "mlp": ["gate_proj", "up_proj", "down_proj"], + "layernorm": ["input_layernorm", "post_attention_layernorm", "norm"], + "use_alibi": False, + }, } @@ -143,10 +149,16 @@ def get_modules_to_fuse(model, quantization_config): elif model.config.model_type in AWQ_FUSED_MAPPINGS: current_fused_mapping = AWQ_FUSED_MAPPINGS[model.config.model_type] + # Properly deal with the case where we have a multi-modal model as well (e.g. Llava) + if not hasattr(model.config, "text_config"): + config = model.config + else: + config = model.config.text_config + # Handle hidden_size, num_attention_heads, num_key_value_heads on our own. 
- hidden_size = model.config.hidden_size - num_attention_heads = model.config.num_attention_heads - num_key_value_heads = getattr(model.config, "num_key_value_heads", num_attention_heads) + hidden_size = config.hidden_size + num_attention_heads = config.num_attention_heads + num_key_value_heads = getattr(config, "num_key_value_heads", num_attention_heads) # Fill `current_fused_mapping` with the expected values current_fused_mapping["hidden_size"] = hidden_size @@ -178,6 +190,7 @@ def fuse_awq_modules(model, quantization_config): backend = awq_config.backend modules_to_fuse = get_modules_to_fuse(model, awq_config) + modules_to_not_convert = getattr(awq_config, "modules_to_not_convert", None) if backend == AwqBackendPackingMethod.AUTOAWQ: from awq.modules.fused.attn import QuantAttentionFused @@ -187,6 +200,10 @@ def fuse_awq_modules(model, quantization_config): raise ValueError("Fusing is only supported for the AutoAWQ backend") for name, module in model.named_modules(): + if modules_to_not_convert is not None: + if any(module_name_to_not_convert in name for module_name_to_not_convert in modules_to_not_convert): + continue + # Replace layer norms _fuse_awq_layernorm(modules_to_fuse["layernorm"], module, FasterTransformerRMSNorm) @@ -248,7 +265,14 @@ def _fuse_awq_mlp(model, current_module_name, fuse_module_names, module, target_ down_proj = getattr(module, fuse_module_names[2]) previous_device = gate_proj.qweight.device - activation_fn = ACT2FN[model.config.hidden_act] + + # Also deal with the case where the model has a `text_config` attribute + hidden_act = ( + model.config.hidden_act + if not hasattr(model.config, "text_config") + else model.config.text_config.hidden_act + ) + activation_fn = ACT2FN[hidden_act] new_module = target_cls(gate_proj, down_proj, up_proj, activation_fn) parent_name, child_name = current_module_name.rsplit(".", 1) @@ -284,7 +308,6 @@ def _fuse_awq_attention_layers(model, module, modules_to_fuse, current_module_na if hasattr(module, 
modules_to_fuse["attention"][0]): # First, we pack the QKV layers together q_proj = getattr(module, modules_to_fuse["attention"][0]) - previous_device = q_proj.qweight.device if isinstance(q_proj, WQLinear_GEMV): linear_target_cls = WQLinear_GEMV @@ -295,6 +318,8 @@ def _fuse_awq_attention_layers(model, module, modules_to_fuse, current_module_na else: raise ValueError("Unsupported q_proj type: {type(q_proj)}") + previous_device = q_proj.qweight.device + k_proj = getattr(module, modules_to_fuse["attention"][1]) v_proj = getattr(module, modules_to_fuse["attention"][2]) o_proj = getattr(module, modules_to_fuse["attention"][3]) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 1958fff73f..2dc5a15a89 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -3583,6 +3583,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix if quantization_config is None: quantization_config = AwqConfig.from_dict(config.quantization_config) + # In case a user passes an `AwqConfig` with `do_fuse=True` for models that have + # a `modules_to_not_convert` attribute we need to manually set that attribute into the + # passed `quantization_config` + elif ( + quantization_config.modules_to_not_convert is None + and "modules_to_not_convert" in config.quantization_config + ): + quantization_config.modules_to_not_convert = config.quantization_config["modules_to_not_convert"] if quantization_config.modules_to_not_convert is not None: modules_to_not_convert.extend(quantization_config.modules_to_not_convert) diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index bd205e0fc9..4264af04a4 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -453,8 +453,15 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel): device=attention_mask.device, ) + # Filter out only the 
tokens that can be un-attended, this can happen + # if one uses Llava + Fused modules where the cache on the + # first iteration is already big enough, or if one passes custom cache + valid_indices = non_attended_tokens < extended_attention_mask.size(-1) + new_batch_index = batch_index[valid_indices] + new_non_attended_tokens = non_attended_tokens[valid_indices] + # Zero-out the places where we don't need to attend - extended_attention_mask[batch_index, non_attended_tokens] = 0 + extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1) position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py index 748c64b22e..ecb8613a7e 100644 --- a/src/transformers/models/vipllava/modeling_vipllava.py +++ b/src/transformers/models/vipllava/modeling_vipllava.py @@ -452,8 +452,15 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel): device=attention_mask.device, ) + # Filter out only the tokens that can be un-attended, this can happen + # in the case one uses Llava + Fused modules where the cache on the + # first iteration is already big enough, or if one passes custom cache + valid_indices = non_attended_tokens < extended_attention_mask.size(-1) + new_batch_index = batch_index[valid_indices] + new_non_attended_tokens = non_attended_tokens[valid_indices] + # Zero-out the places where we don't need to attend - extended_attention_mask[batch_index, non_attended_tokens] = 0 + extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1) position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 diff --git a/tests/quantization/autoawq/test_awq.py b/tests/quantization/autoawq/test_awq.py index 3f5118635a..6ce7fca8fc 100644 --- a/tests/quantization/autoawq/test_awq.py 
+++ b/tests/quantization/autoawq/test_awq.py @@ -254,6 +254,9 @@ class AwqFusedTest(unittest.TestCase): custom_mapping_model_id = "TheBloke/Yi-34B-AWQ" custom_model_revision = "f1b2cd1b7459ceecfdc1fac5bb8725f13707c589" + multi_modal_model_name = "ybelkada/llava-1.5-7b-hf-awq" + multi_modal_model_code_revision = "ad108a50f5b9e681bdd7378409f57b7fa59a7442" + prompt = ( "You're standing on the surface of the Earth. " "You walk one mile south, one mile west and one mile north. " @@ -344,6 +347,29 @@ class AwqFusedTest(unittest.TestCase): self.assertEqual(tokenizer.decode(outputs[0], skip_special_tokens=True), self.EXPECTED_GENERATION) + def test_generation_llava_fused(self): + from transformers import pipeline + + quantization_config = AwqConfig(do_fuse=True, fuse_max_seq_len=2048) + + pipe = pipeline( + "image-to-text", + model=self.multi_modal_model_name, + device=0, + model_kwargs={ + "quantization_config": quantization_config, + }, + revision=self.multi_modal_model_code_revision, + ) + url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/compel-neg.png" + + prompt = "USER: \nCan you please describe this image?\nASSISTANT:" + + outputs = pipe(url, prompt=prompt, generate_kwargs={"max_new_tokens": 100}) + EXPECTED_OUTPUT = "USER: \nCan you please describe this image?\nASSISTANT: The image features a brown and white cat sitting on a green surface, possibly a carpet or a grassy area. The cat is holding a red ball in its paws, seemingly playing with it. The cat appears to be focused on the ball, possibly preparing to play or just enjoying the toy." + + self.assertEqual(outputs[0]["generated_text"], EXPECTED_OUTPUT) + @require_torch_multi_gpu def test_generation_custom_model(self): """