From d70347726577e9823e35c11883e98c5b2c520b37 Mon Sep 17 00:00:00 2001 From: laurentd-lunit <103402801+laurentd-lunit@users.noreply.github.com> Date: Wed, 4 Sep 2024 21:41:51 +0900 Subject: [PATCH] [fix] LlavaNextProcessor '_get_unpadded_features' method (#33263) * [fix] LlavaNextProcessor '_get_unpadded_features' method * [tests] add test_image_token_filling * [chore] style + comment * [minor] improve readability * [chore] run make fix-copies --- .../llava_next/processing_llava_next.py | 4 +-- .../llava_next/test_processor_llava_next.py | 28 +++++++++++++++++++ 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/llava_next/processing_llava_next.py b/src/transformers/models/llava_next/processing_llava_next.py index c043b8bc7e..f84578d1f3 100644 --- a/src/transformers/models/llava_next/processing_llava_next.py +++ b/src/transformers/models/llava_next/processing_llava_next.py @@ -199,8 +199,8 @@ class LlavaNextProcessor(ProcessorMixin): because it divided each image into patches depending on its resolution. Therefore we need to calculate how many patches an image is divided into and get the number of features from that. """ - current_width = patches_height * scale_height - current_height = patches_width * scale_width + current_height = patches_height * scale_height + current_width = patches_width * scale_width original_aspect_ratio = width / height current_aspect_ratio = current_width / current_height diff --git a/tests/models/llava_next/test_processor_llava_next.py b/tests/models/llava_next/test_processor_llava_next.py index 0d2deacdab..c8b58ce798 100644 --- a/tests/models/llava_next/test_processor_llava_next.py +++ b/tests/models/llava_next/test_processor_llava_next.py @@ -13,6 +13,8 @@ # limitations under the License. import unittest +import torch + from transformers.testing_utils import require_vision from transformers.utils import is_vision_available @@ -39,3 +41,29 @@ class LlavaProcessorTest(unittest.TestCase): formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True) self.assertEqual(expected_prompt, formatted_prompt) + + def test_image_token_filling(self): + processor = AutoProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-7b-hf") + processor.patch_size = 14 + processor.vision_feature_select_strategy = "default" + # Important to check with non square image + image = torch.randint(0, 2, (3, 500, 316)) + expected_image_tokens = 1526 + image_token_index = 32000 + + messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What is shown in this image?"}, + ], + }, + ] + inputs = processor( + text=[processor.apply_chat_template(messages)], + images=[image], + return_tensors="pt", + ) + image_tokens = (inputs["input_ids"] == image_token_index).sum().item() + self.assertEqual(expected_image_tokens, image_tokens)