[fix] LlavaNextProcessor '_get_unpadded_features' method (#33263)
* [fix] LlavaNextProcessor '_get_unpadded_features' method * [tests] add test_image_token_filling * [chore] style + comment * [minor] improve readability * [chore] run make fix-copies
This commit is contained in:
@@ -199,8 +199,8 @@ class LlavaNextProcessor(ProcessorMixin):
|
|||||||
because it divided each image into patches depending on its resolution. Therefore we need to calculate how many
|
because it divided each image into patches depending on its resolution. Therefore we need to calculate how many
|
||||||
patches an image is divided into and get the number of features from that.
|
patches an image is divided into and get the number of features from that.
|
||||||
"""
|
"""
|
||||||
current_width = patches_height * scale_height
|
current_height = patches_height * scale_height
|
||||||
current_height = patches_width * scale_width
|
current_width = patches_width * scale_width
|
||||||
|
|
||||||
original_aspect_ratio = width / height
|
original_aspect_ratio = width / height
|
||||||
current_aspect_ratio = current_width / current_height
|
current_aspect_ratio = current_width / current_height
|
||||||
|
|||||||
@@ -13,6 +13,8 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
from transformers.testing_utils import require_vision
|
from transformers.testing_utils import require_vision
|
||||||
from transformers.utils import is_vision_available
|
from transformers.utils import is_vision_available
|
||||||
|
|
||||||
@@ -39,3 +41,29 @@ class LlavaProcessorTest(unittest.TestCase):
|
|||||||
|
|
||||||
formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
|
formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
|
||||||
self.assertEqual(expected_prompt, formatted_prompt)
|
self.assertEqual(expected_prompt, formatted_prompt)
|
||||||
|
|
||||||
|
def test_image_token_filling(self):
|
||||||
|
processor = AutoProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-7b-hf")
|
||||||
|
processor.patch_size = 14
|
||||||
|
processor.vision_feature_select_strategy = "default"
|
||||||
|
# Important to check with non square image
|
||||||
|
image = torch.randint(0, 2, (3, 500, 316))
|
||||||
|
expected_image_tokens = 1526
|
||||||
|
image_token_index = 32000
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "image"},
|
||||||
|
{"type": "text", "text": "What is shown in this image?"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
inputs = processor(
|
||||||
|
text=[processor.apply_chat_template(messages)],
|
||||||
|
images=[image],
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
image_tokens = (inputs["input_ids"] == image_token_index).sum().item()
|
||||||
|
self.assertEqual(expected_image_tokens, image_tokens)
|
||||||
|
|||||||
Reference in New Issue
Block a user