[Awq] Add llava fused modules support (#28239)
* add llava + fused modules * Update src/transformers/models/llava/modeling_llava.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -36,6 +36,12 @@ AWQ_FUSED_MAPPINGS = {
|
|||||||
"layernorm": ["input_layernorm", "post_attention_layernorm", "norm"],
|
"layernorm": ["input_layernorm", "post_attention_layernorm", "norm"],
|
||||||
"use_alibi": False,
|
"use_alibi": False,
|
||||||
},
|
},
|
||||||
|
"llava": {
|
||||||
|
"attention": ["q_proj", "k_proj", "v_proj", "o_proj"],
|
||||||
|
"mlp": ["gate_proj", "up_proj", "down_proj"],
|
||||||
|
"layernorm": ["input_layernorm", "post_attention_layernorm", "norm"],
|
||||||
|
"use_alibi": False,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -143,10 +149,16 @@ def get_modules_to_fuse(model, quantization_config):
|
|||||||
elif model.config.model_type in AWQ_FUSED_MAPPINGS:
|
elif model.config.model_type in AWQ_FUSED_MAPPINGS:
|
||||||
current_fused_mapping = AWQ_FUSED_MAPPINGS[model.config.model_type]
|
current_fused_mapping = AWQ_FUSED_MAPPINGS[model.config.model_type]
|
||||||
|
|
||||||
|
# Properly deal with the case where we have a multi-modal model as well (e.g. Llava)
|
||||||
|
if not hasattr(model.config, "text_config"):
|
||||||
|
config = model.config
|
||||||
|
else:
|
||||||
|
config = model.config.text_config
|
||||||
|
|
||||||
# Handle hidden_size, num_attention_heads, num_key_value_heads on our own.
|
# Handle hidden_size, num_attention_heads, num_key_value_heads on our own.
|
||||||
hidden_size = model.config.hidden_size
|
hidden_size = config.hidden_size
|
||||||
num_attention_heads = model.config.num_attention_heads
|
num_attention_heads = config.num_attention_heads
|
||||||
num_key_value_heads = getattr(model.config, "num_key_value_heads", num_attention_heads)
|
num_key_value_heads = getattr(config, "num_key_value_heads", num_attention_heads)
|
||||||
|
|
||||||
# Fill `current_fused_mapping` with the expected values
|
# Fill `current_fused_mapping` with the expected values
|
||||||
current_fused_mapping["hidden_size"] = hidden_size
|
current_fused_mapping["hidden_size"] = hidden_size
|
||||||
@@ -178,6 +190,7 @@ def fuse_awq_modules(model, quantization_config):
|
|||||||
backend = awq_config.backend
|
backend = awq_config.backend
|
||||||
|
|
||||||
modules_to_fuse = get_modules_to_fuse(model, awq_config)
|
modules_to_fuse = get_modules_to_fuse(model, awq_config)
|
||||||
|
modules_to_not_convert = getattr(awq_config, "modules_to_not_convert", None)
|
||||||
|
|
||||||
if backend == AwqBackendPackingMethod.AUTOAWQ:
|
if backend == AwqBackendPackingMethod.AUTOAWQ:
|
||||||
from awq.modules.fused.attn import QuantAttentionFused
|
from awq.modules.fused.attn import QuantAttentionFused
|
||||||
@@ -187,6 +200,10 @@ def fuse_awq_modules(model, quantization_config):
|
|||||||
raise ValueError("Fusing is only supported for the AutoAWQ backend")
|
raise ValueError("Fusing is only supported for the AutoAWQ backend")
|
||||||
|
|
||||||
for name, module in model.named_modules():
|
for name, module in model.named_modules():
|
||||||
|
if modules_to_not_convert is not None:
|
||||||
|
if any(module_name_to_not_convert in name for module_name_to_not_convert in modules_to_not_convert):
|
||||||
|
continue
|
||||||
|
|
||||||
# Replace layer norms
|
# Replace layer norms
|
||||||
_fuse_awq_layernorm(modules_to_fuse["layernorm"], module, FasterTransformerRMSNorm)
|
_fuse_awq_layernorm(modules_to_fuse["layernorm"], module, FasterTransformerRMSNorm)
|
||||||
|
|
||||||
@@ -248,7 +265,14 @@ def _fuse_awq_mlp(model, current_module_name, fuse_module_names, module, target_
|
|||||||
down_proj = getattr(module, fuse_module_names[2])
|
down_proj = getattr(module, fuse_module_names[2])
|
||||||
|
|
||||||
previous_device = gate_proj.qweight.device
|
previous_device = gate_proj.qweight.device
|
||||||
activation_fn = ACT2FN[model.config.hidden_act]
|
|
||||||
|
# Deal also with the case model has `text_config` attribute
|
||||||
|
hidden_act = (
|
||||||
|
model.config.hidden_act
|
||||||
|
if not hasattr(model.config, "text_config")
|
||||||
|
else model.config.text_config.hidden_act
|
||||||
|
)
|
||||||
|
activation_fn = ACT2FN[hidden_act]
|
||||||
new_module = target_cls(gate_proj, down_proj, up_proj, activation_fn)
|
new_module = target_cls(gate_proj, down_proj, up_proj, activation_fn)
|
||||||
|
|
||||||
parent_name, child_name = current_module_name.rsplit(".", 1)
|
parent_name, child_name = current_module_name.rsplit(".", 1)
|
||||||
@@ -284,7 +308,6 @@ def _fuse_awq_attention_layers(model, module, modules_to_fuse, current_module_na
|
|||||||
if hasattr(module, modules_to_fuse["attention"][0]):
|
if hasattr(module, modules_to_fuse["attention"][0]):
|
||||||
# First, we pack the QKV layers together
|
# First, we pack the QKV layers together
|
||||||
q_proj = getattr(module, modules_to_fuse["attention"][0])
|
q_proj = getattr(module, modules_to_fuse["attention"][0])
|
||||||
previous_device = q_proj.qweight.device
|
|
||||||
|
|
||||||
if isinstance(q_proj, WQLinear_GEMV):
|
if isinstance(q_proj, WQLinear_GEMV):
|
||||||
linear_target_cls = WQLinear_GEMV
|
linear_target_cls = WQLinear_GEMV
|
||||||
@@ -295,6 +318,8 @@ def _fuse_awq_attention_layers(model, module, modules_to_fuse, current_module_na
|
|||||||
else:
|
else:
|
||||||
raise ValueError("Unsupported q_proj type: {type(q_proj)}")
|
raise ValueError("Unsupported q_proj type: {type(q_proj)}")
|
||||||
|
|
||||||
|
previous_device = q_proj.qweight.device
|
||||||
|
|
||||||
k_proj = getattr(module, modules_to_fuse["attention"][1])
|
k_proj = getattr(module, modules_to_fuse["attention"][1])
|
||||||
v_proj = getattr(module, modules_to_fuse["attention"][2])
|
v_proj = getattr(module, modules_to_fuse["attention"][2])
|
||||||
o_proj = getattr(module, modules_to_fuse["attention"][3])
|
o_proj = getattr(module, modules_to_fuse["attention"][3])
|
||||||
|
|||||||
@@ -3583,6 +3583,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
|
|
||||||
if quantization_config is None:
|
if quantization_config is None:
|
||||||
quantization_config = AwqConfig.from_dict(config.quantization_config)
|
quantization_config = AwqConfig.from_dict(config.quantization_config)
|
||||||
|
# In case a user passes a `AwqConfig` with `do_fuse=True` for models that have
|
||||||
|
# a `modules_to_not_convert` attribute we need to manually set that attribute into the
|
||||||
|
# passed `quantization_config`
|
||||||
|
elif (
|
||||||
|
quantization_config.modules_to_not_convert is None
|
||||||
|
and "modules_to_not_convert" in config.quantization_config
|
||||||
|
):
|
||||||
|
quantization_config.modules_to_not_convert = config.quantization_config["modules_to_not_convert"]
|
||||||
|
|
||||||
if quantization_config.modules_to_not_convert is not None:
|
if quantization_config.modules_to_not_convert is not None:
|
||||||
modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
|
modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
|
||||||
|
|||||||
@@ -453,8 +453,15 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
|
|||||||
device=attention_mask.device,
|
device=attention_mask.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Filter out only the tokens that can be un-attended, this can happen
|
||||||
|
# if one uses Llava + Fused modules where the cache on the
|
||||||
|
# first iteration is already big enough, or if one passes custom cache
|
||||||
|
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
|
||||||
|
new_batch_index = batch_index[valid_indices]
|
||||||
|
new_non_attended_tokens = non_attended_tokens[valid_indices]
|
||||||
|
|
||||||
# Zero-out the places where we don't need to attend
|
# Zero-out the places where we don't need to attend
|
||||||
extended_attention_mask[batch_index, non_attended_tokens] = 0
|
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
|
||||||
|
|
||||||
attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
|
attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
|
||||||
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
|
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
|
||||||
|
|||||||
@@ -452,8 +452,15 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel):
|
|||||||
device=attention_mask.device,
|
device=attention_mask.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Filter out only the tokens that can be un-attended, this can happen
|
||||||
|
# in the case one uses Llava + Fused modules where the cache on the
|
||||||
|
# first iteration is already big enough, or if one passes custom cache
|
||||||
|
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
|
||||||
|
new_batch_index = batch_index[valid_indices]
|
||||||
|
new_non_attended_tokens = non_attended_tokens[valid_indices]
|
||||||
|
|
||||||
# Zero-out the places where we don't need to attend
|
# Zero-out the places where we don't need to attend
|
||||||
extended_attention_mask[batch_index, non_attended_tokens] = 0
|
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
|
||||||
|
|
||||||
attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
|
attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
|
||||||
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
|
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
|
||||||
|
|||||||
@@ -254,6 +254,9 @@ class AwqFusedTest(unittest.TestCase):
|
|||||||
custom_mapping_model_id = "TheBloke/Yi-34B-AWQ"
|
custom_mapping_model_id = "TheBloke/Yi-34B-AWQ"
|
||||||
custom_model_revision = "f1b2cd1b7459ceecfdc1fac5bb8725f13707c589"
|
custom_model_revision = "f1b2cd1b7459ceecfdc1fac5bb8725f13707c589"
|
||||||
|
|
||||||
|
multi_modal_model_name = "ybelkada/llava-1.5-7b-hf-awq"
|
||||||
|
multi_modal_model_code_revision = "ad108a50f5b9e681bdd7378409f57b7fa59a7442"
|
||||||
|
|
||||||
prompt = (
|
prompt = (
|
||||||
"You're standing on the surface of the Earth. "
|
"You're standing on the surface of the Earth. "
|
||||||
"You walk one mile south, one mile west and one mile north. "
|
"You walk one mile south, one mile west and one mile north. "
|
||||||
@@ -344,6 +347,29 @@ class AwqFusedTest(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertEqual(tokenizer.decode(outputs[0], skip_special_tokens=True), self.EXPECTED_GENERATION)
|
self.assertEqual(tokenizer.decode(outputs[0], skip_special_tokens=True), self.EXPECTED_GENERATION)
|
||||||
|
|
||||||
|
def test_generation_llava_fused(self):
|
||||||
|
from transformers import pipeline
|
||||||
|
|
||||||
|
quantization_config = AwqConfig(do_fuse=True, fuse_max_seq_len=2048)
|
||||||
|
|
||||||
|
pipe = pipeline(
|
||||||
|
"image-to-text",
|
||||||
|
model=self.multi_modal_model_name,
|
||||||
|
device=0,
|
||||||
|
model_kwargs={
|
||||||
|
"quantization_config": quantization_config,
|
||||||
|
},
|
||||||
|
revision=self.multi_modal_model_code_revision,
|
||||||
|
)
|
||||||
|
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/compel-neg.png"
|
||||||
|
|
||||||
|
prompt = "USER: <image>\nCan you please describe this image?\nASSISTANT:"
|
||||||
|
|
||||||
|
outputs = pipe(url, prompt=prompt, generate_kwargs={"max_new_tokens": 100})
|
||||||
|
EXPECTED_OUTPUT = "USER: \nCan you please describe this image?\nASSISTANT: The image features a brown and white cat sitting on a green surface, possibly a carpet or a grassy area. The cat is holding a red ball in its paws, seemingly playing with it. The cat appears to be focused on the ball, possibly preparing to play or just enjoying the toy."
|
||||||
|
|
||||||
|
self.assertEqual(outputs[0]["generated_text"], EXPECTED_OUTPUT)
|
||||||
|
|
||||||
@require_torch_multi_gpu
|
@require_torch_multi_gpu
|
||||||
def test_generation_custom_model(self):
|
def test_generation_custom_model(self):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user