[Awq] Add llava fused modules support (#28239)
* add llava + fused modules * Update src/transformers/models/llava/modeling_llava.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -254,6 +254,9 @@ class AwqFusedTest(unittest.TestCase):
|
||||
custom_mapping_model_id = "TheBloke/Yi-34B-AWQ"
|
||||
custom_model_revision = "f1b2cd1b7459ceecfdc1fac5bb8725f13707c589"
|
||||
|
||||
multi_modal_model_name = "ybelkada/llava-1.5-7b-hf-awq"
|
||||
multi_modal_model_code_revision = "ad108a50f5b9e681bdd7378409f57b7fa59a7442"
|
||||
|
||||
prompt = (
|
||||
"You're standing on the surface of the Earth. "
|
||||
"You walk one mile south, one mile west and one mile north. "
|
||||
@@ -344,6 +347,29 @@ class AwqFusedTest(unittest.TestCase):
|
||||
|
||||
self.assertEqual(tokenizer.decode(outputs[0], skip_special_tokens=True), self.EXPECTED_GENERATION)
|
||||
|
||||
def test_generation_llava_fused(self):
|
||||
from transformers import pipeline
|
||||
|
||||
quantization_config = AwqConfig(do_fuse=True, fuse_max_seq_len=2048)
|
||||
|
||||
pipe = pipeline(
|
||||
"image-to-text",
|
||||
model=self.multi_modal_model_name,
|
||||
device=0,
|
||||
model_kwargs={
|
||||
"quantization_config": quantization_config,
|
||||
},
|
||||
revision=self.multi_modal_model_code_revision,
|
||||
)
|
||||
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/compel-neg.png"
|
||||
|
||||
prompt = "USER: <image>\nCan you please describe this image?\nASSISTANT:"
|
||||
|
||||
outputs = pipe(url, prompt=prompt, generate_kwargs={"max_new_tokens": 100})
|
||||
EXPECTED_OUTPUT = "USER: \nCan you please describe this image?\nASSISTANT: The image features a brown and white cat sitting on a green surface, possibly a carpet or a grassy area. The cat is holding a red ball in its paws, seemingly playing with it. The cat appears to be focused on the ball, possibly preparing to play or just enjoying the toy."
|
||||
|
||||
self.assertEqual(outputs[0]["generated_text"], EXPECTED_OUTPUT)
|
||||
|
||||
@require_torch_multi_gpu
|
||||
def test_generation_custom_model(self):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user