diff --git a/src/transformers/models/llava_next/processing_llava_next.py b/src/transformers/models/llava_next/processing_llava_next.py index c043b8bc7e..f84578d1f3 100644 --- a/src/transformers/models/llava_next/processing_llava_next.py +++ b/src/transformers/models/llava_next/processing_llava_next.py @@ -199,8 +199,8 @@ class LlavaNextProcessor(ProcessorMixin): because it divided each image into patches depending on its resolution. Therefore we need to calculate how many patches an image is divided into and get the number of features from that. """ - current_width = patches_height * scale_height - current_height = patches_width * scale_width + current_height = patches_height * scale_height + current_width = patches_width * scale_width original_aspect_ratio = width / height current_aspect_ratio = current_width / current_height diff --git a/tests/models/llava_next/test_processor_llava_next.py b/tests/models/llava_next/test_processor_llava_next.py index 0d2deacdab..c8b58ce798 100644 --- a/tests/models/llava_next/test_processor_llava_next.py +++ b/tests/models/llava_next/test_processor_llava_next.py @@ -13,6 +13,8 @@ # limitations under the License. import unittest +import torch + from transformers.testing_utils import require_vision from transformers.utils import is_vision_available @@ -39,3 +41,29 @@ class LlavaProcessorTest(unittest.TestCase): formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True) self.assertEqual(expected_prompt, formatted_prompt) + + def test_image_token_filling(self): + processor = AutoProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-7b-hf") + processor.patch_size = 14 + processor.vision_feature_select_strategy = "default" + # Important to check with non square image + image = torch.randint(0, 2, (3, 500, 316)) + expected_image_tokens = 1526 + image_token_index = 32000 + + messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What is shown in this image?"}, + ], + }, + ] + inputs = processor( + text=[processor.apply_chat_template(messages)], + images=[image], + return_tensors="pt", + ) + image_tokens = (inputs["input_ids"] == image_token_index).sum().item() + self.assertEqual(expected_image_tokens, image_tokens)