From fad15fba78e4603cd20695757ad899a6687485f9 Mon Sep 17 00:00:00 2001
From: Raushan Turganbay
Date: Fri, 26 Jul 2024 10:17:27 +0500
Subject: [PATCH] Llava: generate without images (#32183)

* llava w/o images

* tests
---
 src/transformers/models/llava/processing_llava.py |  6 +++---
 tests/models/llava/test_modeling_llava.py         | 13 +++++++++++++
 2 files changed, 16 insertions(+), 3 deletions(-)

diff --git a/src/transformers/models/llava/processing_llava.py b/src/transformers/models/llava/processing_llava.py
index e51e6ba076..a563b1cb82 100644
--- a/src/transformers/models/llava/processing_llava.py
+++ b/src/transformers/models/llava/processing_llava.py
@@ -104,14 +104,14 @@ class LlavaProcessor(ProcessorMixin):
             - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
         """
         if images is not None:
-            pixel_values = self.image_processor(images, return_tensors=return_tensors)["pixel_values"]
+            image_inputs = self.image_processor(images, return_tensors=return_tensors)
         else:
-            pixel_values = None
+            image_inputs = {}
         text_inputs = self.tokenizer(
             text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length
         )
 
-        return BatchFeature(data={**text_inputs, "pixel_values": pixel_values})
+        return BatchFeature(data={**text_inputs, **image_inputs})
 
     # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
     def batch_decode(self, *args, **kwargs):
diff --git a/tests/models/llava/test_modeling_llava.py b/tests/models/llava/test_modeling_llava.py
index b37e4df3cc..ce13ab6738 100644
--- a/tests/models/llava/test_modeling_llava.py
+++ b/tests/models/llava/test_modeling_llava.py
@@ -458,3 +458,16 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
         EXPECTED_OUTPUT = ['<|im_start|>', 'system', '\n', 'Answer', '▁the', '▁questions', '.', '<|im_end|>', '<|im_start|>', 'user', '\n', '<image>', '\n', 'What', '▁is', '▁shown', '▁in', '▁this', '▁image', '?', '<|im_end|>', '<|im_start|>', 'ass', 'istant', '\n']  # fmt: skip
         self.assertEqual(slow_tokenizer.tokenize(prompt), EXPECTED_OUTPUT)
         self.assertEqual(fast_tokenizer.tokenize(prompt), EXPECTED_OUTPUT)
+
+    @slow
+    @require_bitsandbytes
+    def test_generation_no_images(self):
+        model_id = "llava-hf/llava-1.5-7b-hf"
+        model = LlavaForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True)
+        processor = AutoProcessor.from_pretrained(model_id)
+
+        # Prepare inputs with no images
+        inputs = processor("Hello, I am", return_tensors="pt").to(torch_device)
+
+        # Make sure that `generate` works
+        _ = model.generate(**inputs, max_new_tokens=20)