Fix continue_final_message for image-text-to-text chat templates (#34236)
* fix continue_final_message for vlms * Add one test for vlms continue_final_message chat template
This commit is contained in:
@@ -1874,7 +1874,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
**template_kwargs,
|
**template_kwargs,
|
||||||
)
|
)
|
||||||
if continue_final_message:
|
if continue_final_message:
|
||||||
final_message = chat[-1]["content"].strip()
|
final_message = chat[-1]["content"]
|
||||||
|
if isinstance(final_message, (list, tuple)):
|
||||||
|
final_message = final_message[-1]["text"]
|
||||||
|
final_message = final_message.strip()
|
||||||
rendered_chat = rendered_chat[: rendered_chat.rindex(final_message) + len(final_message)].rstrip()
|
rendered_chat = rendered_chat[: rendered_chat.rindex(final_message) + len(final_message)].rstrip()
|
||||||
rendered.append(rendered_chat)
|
rendered.append(rendered_chat)
|
||||||
|
|
||||||
|
|||||||
@@ -93,3 +93,24 @@ class LlavaProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
|
formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
|
||||||
self.assertEqual(expected_prompt, formatted_prompt)
|
self.assertEqual(expected_prompt, formatted_prompt)
|
||||||
|
|
||||||
|
def test_chat_template_with_continue_final_message(self):
|
||||||
|
processor = LlavaProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
|
||||||
|
expected_prompt = "USER: <image>\nDescribe this image. ASSISTANT: There is a dog and"
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "image"},
|
||||||
|
{"type": "text", "text": "Describe this image."},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "There is a dog and"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
prompt = processor.apply_chat_template(messages, continue_final_message=True)
|
||||||
|
self.assertEqual(expected_prompt, prompt)
|
||||||
|
|||||||
Reference in New Issue
Block a user