[chat-template] Unify tests and clean up 🧼 (#37275)

* fix tests and some clean up

* make one general test for each modality

* remove redundant merging of kwargs

* edge cases

* dont enforce slow when reloading

* fix gemma3 tests

* has to adapt llama 4 after rebase

* remove also from overriden tests

* should be green now
This commit is contained in:
Raushan Turganbay
2025-04-10 14:42:32 +02:00
committed by GitHub
parent 10144ff116
commit 1ae8d54b04
18 changed files with 389 additions and 1112 deletions

View File

@@ -86,67 +86,3 @@ class LlavaProcessorTest(ProcessorTesterMixin, unittest.TestCase):
processor = LlavaProcessor.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
self.assertEqual(processor.tokenizer.__class__, tokenizer.__class__)
def test_chat_template(self):
processor = LlavaProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
expected_prompt = "USER: <image>\nWhat is shown in this image? ASSISTANT:"
messages = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": "What is shown in this image?"},
],
},
]
formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
self.assertEqual(expected_prompt, formatted_prompt)
def test_chat_template_dict(self):
processor = LlavaProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
messages = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": "What is shown in this image?"},
],
},
]
formatted_prompt_tokenized = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
expected_output = [[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 338, 4318, 297, 445, 1967, 29973, 319, 1799, 9047, 13566, 29901]] # fmt: skip
self.assertListEqual(expected_output, formatted_prompt_tokenized)
out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])
# add image URL for return dict
messages[0]["content"][0] = {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}
out_dict_with_image = processor.apply_chat_template(
messages, add_generation_prompt=True, tokenize=True, return_dict=True
)
self.assertListEqual(list(out_dict_with_image.keys()), ["input_ids", "attention_mask", "pixel_values"])
def test_chat_template_with_continue_final_message(self):
processor = LlavaProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
expected_prompt = "USER: <image>\nDescribe this image. ASSISTANT: There is a dog and"
messages = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": "Describe this image."},
],
},
{
"role": "assistant",
"content": [
{"type": "text", "text": "There is a dog and"},
],
},
]
prompt = processor.apply_chat_template(messages, continue_final_message=True)
self.assertEqual(expected_prompt, prompt)