[Awq] Add llava fused modules support (#28239)

* add llava + fused modules

* Update src/transformers/models/llava/modeling_llava.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Younes Belkada
2024-01-12 06:55:54 +01:00
committed by GitHub
parent 995a7ce9a8
commit 07bdbebb48
5 changed files with 80 additions and 7 deletions

View File

@@ -254,6 +254,9 @@ class AwqFusedTest(unittest.TestCase):
custom_mapping_model_id = "TheBloke/Yi-34B-AWQ"
custom_model_revision = "f1b2cd1b7459ceecfdc1fac5bb8725f13707c589"
multi_modal_model_name = "ybelkada/llava-1.5-7b-hf-awq"
multi_modal_model_code_revision = "ad108a50f5b9e681bdd7378409f57b7fa59a7442"
prompt = (
"You're standing on the surface of the Earth. "
"You walk one mile south, one mile west and one mile north. "
@@ -344,6 +347,29 @@ class AwqFusedTest(unittest.TestCase):
self.assertEqual(tokenizer.decode(outputs[0], skip_special_tokens=True), self.EXPECTED_GENERATION)
def test_generation_llava_fused(self):
from transformers import pipeline
quantization_config = AwqConfig(do_fuse=True, fuse_max_seq_len=2048)
pipe = pipeline(
"image-to-text",
model=self.multi_modal_model_name,
device=0,
model_kwargs={
"quantization_config": quantization_config,
},
revision=self.multi_modal_model_code_revision,
)
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/compel-neg.png"
prompt = "USER: <image>\nCan you please describe this image?\nASSISTANT:"
outputs = pipe(url, prompt=prompt, generate_kwargs={"max_new_tokens": 100})
EXPECTED_OUTPUT = "USER: \nCan you please describe this image?\nASSISTANT: The image features a brown and white cat sitting on a green surface, possibly a carpet or a grassy area. The cat is holding a red ball in its paws, seemingly playing with it. The cat appears to be focused on the ball, possibly preparing to play or just enjoying the toy."
self.assertEqual(outputs[0]["generated_text"], EXPECTED_OUTPUT)
@require_torch_multi_gpu
def test_generation_custom_model(self):
"""