[internvl] fix chat template (#37656)

* fix chat template

* update

* update conversion

* rename `fake_image_token` in tests
This commit is contained in:
Raushan Turganbay
2025-04-23 16:56:36 +02:00
committed by GitHub
parent 9ec8be56dd
commit 1e9087368c
5 changed files with 88 additions and 120 deletions

View File

@@ -64,7 +64,8 @@ class InternVLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
**processor_kwargs,
)
processor.save_pretrained(cls.tmpdirname)
cls.image_token = processor.fake_image_token
cls.image_token = processor.image_token
cls.video_token = processor.video_token
@staticmethod
def prepare_processor_dict():
@@ -138,6 +139,7 @@ class InternVLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
return_dict=True,
return_tensors="pt",
padding=True,
num_frames=8,
)
# Process non batched inputs to check if the pixel_values and input_ids are reconstructed in the correct order when batched together
@@ -150,6 +152,7 @@ class InternVLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
return_dict=True,
return_tensors="pt",
padding=True,
num_frames=8,
)
# We slice with [-inputs["input_ids"].shape[1] :] as the input_ids are left padded
torch.testing.assert_close(
@@ -223,6 +226,7 @@ class InternVLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
tokenize=True,
return_dict=True,
return_tensors="np",
num_frames=8,
)
self.assertTrue(self.videos_input_name in out_dict_with_video)
@@ -272,30 +276,7 @@ class InternVLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
self.assertTrue(self.videos_input_name in out_dict_with_video)
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), num_frames)
# Load with `video_fps` arg
video_fps = 1
out_dict_with_video = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
video_fps=video_fps,
num_frames=None, # force to use default num_frames
return_tensors="np",
)
self.assertTrue(self.videos_input_name in out_dict_with_video)
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), video_fps * 10)
# Load with `video_fps` and `num_frames` args, should raise an error
with self.assertRaises(ValueError):
out_dict_with_video = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
video_fps=video_fps,
num_frames=num_frames,
)
# Load with `video_fps` arg is not possible with InternVL (skip)
# Load without any arg should use the default loading method
out_dict_with_video = processor.apply_chat_template(
@@ -305,8 +286,7 @@ class InternVLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
return_dict=True,
)
self.assertTrue(self.videos_input_name in out_dict_with_video)
# Difference with common tests, InternVLProcessor returns flattened video features, and uses 8 frames by default
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 8)
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 300)
# Load video as a list of frames (i.e. images). NOTE: each frame should have same size
# because we assume they come from one video