[fix] LlavaNextProcessor '_get_unpadded_features' method (#33263)
* [fix] LlavaNextProcessor '_get_unpadded_features' method * [tests] add test_image_token_filling * [chore] style + comment * [minor] improve readability * [chore] run make fix-copies
This commit is contained in:
@@ -13,6 +13,8 @@
|
||||
# limitations under the License.
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from transformers.testing_utils import require_vision
|
||||
from transformers.utils import is_vision_available
|
||||
|
||||
@@ -39,3 +41,29 @@ class LlavaProcessorTest(unittest.TestCase):
|
||||
|
||||
formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
|
||||
self.assertEqual(expected_prompt, formatted_prompt)
|
||||
|
||||
def test_image_token_filling(self):
|
||||
processor = AutoProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-7b-hf")
|
||||
processor.patch_size = 14
|
||||
processor.vision_feature_select_strategy = "default"
|
||||
# Important to check with non square image
|
||||
image = torch.randint(0, 2, (3, 500, 316))
|
||||
expected_image_tokens = 1526
|
||||
image_token_index = 32000
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": "What is shown in this image?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
inputs = processor(
|
||||
text=[processor.apply_chat_template(messages)],
|
||||
images=[image],
|
||||
return_tensors="pt",
|
||||
)
|
||||
image_tokens = (inputs["input_ids"] == image_token_index).sum().item()
|
||||
self.assertEqual(expected_image_tokens, image_tokens)
|
||||
|
||||
Reference in New Issue
Block a user