[fix] LlavaNextProcessor '_get_unpadded_features' method (#33263)

* [fix] LlavaNextProcessor '_get_unpadded_features' method

* [tests] add test_image_token_filling

* [chore] style + comment

* [minor] improve readability

* [chore] run make fix-copies
This commit is contained in:
laurentd-lunit
2024-09-04 21:41:51 +09:00
committed by GitHub
parent d750b509fc
commit d703477265
2 changed files with 30 additions and 2 deletions

View File

@@ -13,6 +13,8 @@
# limitations under the License.
import unittest
import torch
from transformers.testing_utils import require_vision
from transformers.utils import is_vision_available
@@ -39,3 +41,29 @@ class LlavaProcessorTest(unittest.TestCase):
formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
self.assertEqual(expected_prompt, formatted_prompt)
def test_image_token_filling(self):
    """Check how many image-token ids the processor emits for one non-square image.

    Loads the ``llava-hf/llava-v1.6-vicuna-7b-hf`` processor, overrides its
    ``patch_size`` and ``vision_feature_select_strategy``, runs it on a chat
    prompt containing one image entry, and asserts that the number of
    ``input_ids`` entries equal to ``image_token_index`` matches
    ``expected_image_tokens``.
    """
    # Reference values the assertion is pinned to.
    image_token_index = 32000
    expected_image_tokens = 1526

    processor = AutoProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-7b-hf")
    processor.patch_size = 14
    processor.vision_feature_select_strategy = "default"

    # The image is deliberately non-square (500 x 316): that is the case this test must cover.
    image = torch.randint(0, 2, (3, 500, 316))

    user_content = [
        {"type": "image"},
        {"type": "text", "text": "What is shown in this image?"},
    ]
    messages = [{"role": "user", "content": user_content}]

    prompt = processor.apply_chat_template(messages)
    inputs = processor(text=[prompt], images=[image], return_tensors="pt")

    # Count every position in input_ids holding the image token id.
    image_token_mask = inputs["input_ids"] == image_token_index
    image_tokens = image_token_mask.sum().item()
    self.assertEqual(expected_image_tokens, image_tokens)