[internvl] fix chat template (#37656)

* fix chat template

* update

* update conversion

* rename `fake_image_token` in tests
This commit is contained in:
Raushan Turganbay
2025-04-23 16:56:36 +02:00
committed by GitHub
parent 9ec8be56dd
commit 1e9087368c
5 changed files with 88 additions and 120 deletions

View File

@@ -257,6 +257,7 @@ InternVL models can also handle video inputs. Here is an example of how to perfo
... add_generation_prompt=True,
... tokenize=True,
... return_dict=True,
... num_frames=8,
>>> ).to(model.device, dtype=torch.float16)
>>> output = model.generate(**inputs, max_new_tokens=25)

View File

@@ -312,6 +312,7 @@ def write_tokenizer(save_dir: str, push_to_hub: bool = False, path: str = None,
"start_image_token": "<img>",
"end_image_token": "</img>",
"context_image_token": "<IMG_CONTEXT>",
"video_token": "<video>",
},
)
tokenizer.model_max_length = CONTEXT_LENGTH

View File

@@ -14,13 +14,11 @@
# limitations under the License.
from functools import partial
from typing import Dict, List, Optional, Union
from typing import List, Optional, Union
import numpy as np
from transformers.processing_utils import (
AllKwargsForChatTemplate,
ImagesKwargs,
ProcessingKwargs,
ProcessorMixin,
@@ -34,6 +32,7 @@ from ...image_utils import (
VideoInput,
VideoMetadata,
concatenate_list,
load_video,
make_batched_videos,
make_flat_list_of_images,
)
@@ -75,20 +74,12 @@ class InternVLProcessor(ProcessorMixin):
image_seq_length = (config.image_size // config.patch_size) ** 2 * (config.scale_factor**2)
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
in a chat into a tokenizable string.
fake_image_token (`str`, *optional*, defaults to `"<image>"`):
The token to use for the image placeholder in the text. This token will be replaced by the
appropriate image tokens when processing the text with images.
fake_video_token (`str`, *optional*, defaults to `"<video>"`):
The token to use for the video placeholder in the text. This token will be replaced by the
appropriate image tokens when processing the text with videos.
"""
attributes = ["image_processor", "tokenizer"]
valid_kwargs = [
"chat_template",
"image_seq_length",
"fake_image_token",
"fake_video_token",
]
image_processor_class = "AutoImageProcessor"
tokenizer_class = "AutoTokenizer"
@@ -99,16 +90,14 @@ class InternVLProcessor(ProcessorMixin):
tokenizer=None,
image_seq_length: int = 256,
chat_template=None,
fake_image_token="<image>",
fake_video_token="<video>",
**kwargs,
):
self.image_seq_length = image_seq_length
self.fake_image_token = fake_image_token
self.fake_video_token = fake_video_token
self.start_image_token = tokenizer.start_image_token
self.end_image_token = tokenizer.end_image_token
self.context_image_token = tokenizer.context_image_token
self.image_token = tokenizer.context_image_token
self.video_token = tokenizer.video_token
self.image_token_id = tokenizer.context_image_token_id
super().__init__(image_processor, tokenizer, chat_template=chat_template, **kwargs)
@@ -131,24 +120,24 @@ class InternVLProcessor(ProcessorMixin):
video_index = 0
processed_text = []
image_video_patches = []
replace_strings = []
# Support interleaved image and video in prompts:
# Processed patches of images and videos are inserted in `image_video_patches` in the order they appear in the prompts
for prompt in text:
new_prompt = prompt
while self.fake_image_token in new_prompt or self.fake_video_token in new_prompt:
if self.fake_image_token in new_prompt and (
self.fake_video_token not in new_prompt
or new_prompt.index(self.fake_image_token) < new_prompt.index(self.fake_video_token)
while self.image_token in new_prompt or self.video_token in new_prompt:
if self.image_token in new_prompt and (
self.video_token not in new_prompt
or new_prompt.index(self.image_token) < new_prompt.index(self.video_token)
):
# Get the slice of patches corresponding to the current image
start_index = image_num_patches_indices[image_index - 1] if image_index > 0 else 0
end_index = image_num_patches_indices[image_index]
image_video_patches.append(image_pixel_values[start_index:end_index])
# Replace the corresponding image placeholder with the correct number of image tokens
new_prompt = new_prompt.replace(
self.fake_image_token,
f"{self.start_image_token}{self.context_image_token * self.image_seq_length * image_num_patches[image_index]}{self.end_image_token}",
1,
new_prompt = new_prompt.replace(self.image_token, "<placeholder>", 1)
replace_strings.append(
f"{self.start_image_token}{self.image_token * self.image_seq_length * image_num_patches[image_index]}{self.end_image_token}"
)
image_index += 1
else:
@@ -163,11 +152,15 @@ class InternVLProcessor(ProcessorMixin):
# Get the number of patches per frame and replace the video placeholder with the correct number of image tokens
num_patches = list(video_num_patches[current_patch_index:end_patch_index])
video_prompt = "\n".join(
f"Frame{i + 1}: {self.start_image_token}{self.context_image_token * self.image_seq_length * num_patches[i]}{self.end_image_token}"
f"Frame{i + 1}: {self.start_image_token}{self.image_token * self.image_seq_length * num_patches[i]}{self.end_image_token}"
for i in range(len(num_patches))
)
new_prompt = new_prompt.replace(self.fake_video_token, video_prompt, 1)
replace_strings.append(video_prompt)
new_prompt = new_prompt.replace(self.video_token, "<placeholder>", 1)
video_index += 1
while "<placeholder>" in new_prompt:
replace_str = replace_strings.pop(0)
new_prompt = new_prompt.replace("<placeholder>", replace_str, 1)
processed_text.append(new_prompt)
return processed_text, image_video_patches, image_index, video_index
@@ -269,9 +262,11 @@ class InternVLProcessor(ProcessorMixin):
# Concatenate the interleaved image and video patches (function agnostic to the patches type (list, numpy array, torch tensor))
image_videos_inputs = {"pixel_values": concatenate_list(image_video_patches)}
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
self._check_special_mm_tokens(text, text_inputs, modalities=["image"])
return BatchFeature(data={**text_inputs, **image_videos_inputs})
return BatchFeature(data={**text_inputs, **image_videos_inputs}, tensor_type=return_tensors)
def sample_indices_fn(
self, metadata: VideoMetadata, num_frames: int = None, initial_shift: Union[bool, float, int] = True
@@ -290,15 +285,13 @@ class InternVLProcessor(ProcessorMixin):
Returns:
`np.ndarray`: Array of frame indices to sample.
"""
num_frames = num_frames if num_frames is not None else metadata.total_num_frames
if initial_shift is True:
initial_shift = metadata.total_num_frames / num_frames / 2
if num_frames is not None:
indices = np.arange(
initial_shift, metadata.total_num_frames, metadata.total_num_frames / num_frames
).astype(int)
else:
indices = np.arange(initial_shift, metadata.total_num_frames).astype(int)
indices = np.arange(initial_shift, metadata.total_num_frames, metadata.total_num_frames / num_frames).astype(
int
)
return indices
def batch_decode(self, *args, **kwargs):
@@ -321,58 +314,39 @@ class InternVLProcessor(ProcessorMixin):
image_processor_input_names = self.image_processor.model_input_names
return list(tokenizer_input_names) + list(image_processor_input_names)
# Add model-specific video sampling method when applying the template
def apply_chat_template(
# TODO: raushan, has to be public method under `VideoProcessorBase` when API is added
def _load_video_for_model(
self,
conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]]],
chat_template: Optional[str] = None,
num_frames: int = 8,
initial_shift: Union[bool, float, int] = True,
video_load_backend="pyav",
**kwargs: Unpack[AllKwargsForChatTemplate],
):
video: Union[str, "VideoInput"],
num_frames: Optional[int],
backend: str = "pyav",
initial_shift: bool = True,
**kwargs,
) -> np.array:
"""
Similar to the `apply_chat_template` method on tokenizers, this method applies a Jinja template to input
conversations to turn them into a single tokenizable string.
The input is expected to be in the following format, where each message content is a list consisting of text and
optionally image or video inputs. One can also provide an image, video, URL or local path which will be used to form
`pixel_values` when `return_dict=True`. If not provided, one will get only the formatted text, optionally tokenized text.
conversation = [
{
"role": "user",
"content": [
{"type": "image", "image": "https://www.ilankelman.org/stopsigns/australia.jpg"},
{"type": "text", "text": "Please describe this image in detail."},
],
},
]
Loads `video` to a numpy array.
Args:
conversation (`Union[List[Dict, [str, str]], List[List[Dict[str, str]]]]`):
The conversation to format.
chat_template (`Optional[str]`, *optional*):
The Jinja template to use for formatting the conversation. If not provided, the tokenizer's
chat template is used.
num_frames (`int`, *optional*, defaults to 8):
Number of frames to sample from a video when using the default `sample_indices_fn`.
initial_shift (`bool`, `float` or `int`, defaults to `0`):
The initial shift to apply when sampling frames using the default `sample_indices_fn`.
If `True`, the shift is set so that frames are sampled from the middle of the video.
"""
sample_indices_fn = kwargs.pop(
"sample_indices_fn", partial(self.sample_indices_fn, num_frames=num_frames, initial_shift=initial_shift)
)
video (`str` or `VideoInput`):
The video to convert to the numpy array format. Can be a link to video or local path.
num_frames (`int`, *optional*):
Number of frames to sample uniformly. If not passed, the whole video is loaded.
backend (`str`, *optional*, defaults to `"pyav"`):
The backend to use when loading the video. Can be any of ["decord", "pyav", "opencv", "torchvision"]. Defaults to "pyav".
initial_shift (`bool`, *optional*, defaults to `True`):
The initial shift to apply when sampling frames. If `True`, the shift is set so that frames are sampled from the middle of the video.
return super().apply_chat_template(
conversation,
chat_template,
video_load_backend=video_load_backend,
num_frames=num_frames,
sample_indices_fn=sample_indices_fn,
**kwargs,
)
Returns:
Tuple[`np.array`, Dict]: A tuple containing:
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
- Metadata dictionary.
"""
def sample_indices_fn_func(metadata, **fn_kwargs):
return self.sample_indices_fn(metadata, num_frames=num_frames, initial_shift=initial_shift, **fn_kwargs)
video, metadata = load_video(video, backend=backend, sample_indices_fn=sample_indices_fn_func)
return video, metadata
__all__ = ["InternVLProcessor"]

View File

@@ -296,7 +296,9 @@ class InternVLQwen2IntegrationTest(unittest.TestCase):
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
prompt = "<|im_start|>user\n<image>\nPlease describe the image explicitly.<|im_end|>\n<|im_start|>assistant\n"
prompt = (
"<|im_start|>user\n<IMG_CONTEXT>\nPlease describe the image explicitly.<|im_end|>\n<|im_start|>assistant\n"
)
inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.bfloat16)
with torch.no_grad():
generate_ids = model.generate(**inputs, max_new_tokens=20, do_sample=False)
@@ -314,7 +316,9 @@ class InternVLQwen2IntegrationTest(unittest.TestCase):
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
prompt = "<|im_start|>user\n<image>\nPlease describe the image explicitly.<|im_end|>\n<|im_start|>assistant\n"
prompt = (
"<|im_start|>user\n<IMG_CONTEXT>\nPlease describe the image explicitly.<|im_end|>\n<|im_start|>assistant\n"
)
inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.bfloat16)
# Forward
@@ -378,8 +382,8 @@ class InternVLQwen2IntegrationTest(unittest.TestCase):
)
# Prepare inputs
prompt = [
"<|im_start|>user\n<image>\nWrite a haiku for this image<|im_end|>\n<|im_start|>assistant\n",
"<|im_start|>user\n<image>\nDescribe this image<|im_end|>\n<|im_start|>assistant\n",
"<|im_start|>user\n<IMG_CONTEXT>\nWrite a haiku for this image<|im_end|>\n<|im_start|>assistant\n",
"<|im_start|>user\n<IMG_CONTEXT>\nDescribe this image<|im_end|>\n<|im_start|>assistant\n",
]
image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
image2 = Image.open(requests.get("https://www.ilankelman.org/stopsigns/australia.jpg", stream=True).raw)
@@ -414,8 +418,8 @@ class InternVLQwen2IntegrationTest(unittest.TestCase):
)
# Prepare inputs
prompt = [
"<|im_start|>user\n<image>\nWrite a haiku for this image<|im_end|>\n<|im_start|>assistant\n",
"<|im_start|>user\n<image><image>\nWhat are the differences between these two images?<|im_end|>\n<|im_start|>assistant\n",
"<|im_start|>user\n<IMG_CONTEXT>\nWrite a haiku for this image<|im_end|>\n<|im_start|>assistant\n",
"<|im_start|>user\n<IMG_CONTEXT><IMG_CONTEXT>\nWhat are the differences between these two images?<|im_end|>\n<|im_start|>assistant\n",
]
image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
image2 = Image.open(
@@ -485,6 +489,7 @@ class InternVLQwen2IntegrationTest(unittest.TestCase):
tokenize=True,
return_dict=True,
return_tensors="pt",
num_frames=8,
).to(torch_device, dtype=torch.float16)
output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
@@ -552,6 +557,7 @@ class InternVLQwen2IntegrationTest(unittest.TestCase):
return_dict=True,
return_tensors="pt",
padding=True,
num_frames=8,
).to(torch_device, dtype=torch.bfloat16)
output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
@@ -601,7 +607,9 @@ class InternVLLlamaIntegrationTest(unittest.TestCase):
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
prompt = "<|im_start|>user\n<image>\nPlease describe the image explicitly.<|im_end|>\n<|im_start|>assistant\n"
prompt = (
"<|im_start|>user\n<IMG_CONTEXT>\nPlease describe the image explicitly.<|im_end|>\n<|im_start|>assistant\n"
)
inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.bfloat16)
with torch.no_grad():
generate_ids = model.generate(**inputs, max_new_tokens=20, do_sample=False)
@@ -619,7 +627,9 @@ class InternVLLlamaIntegrationTest(unittest.TestCase):
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
prompt = "<|im_start|>user\n<image>\nPlease describe the image explicitly.<|im_end|>\n<|im_start|>assistant\n"
prompt = (
"<|im_start|>user\n<IMG_CONTEXT>\nPlease describe the image explicitly.<|im_end|>\n<|im_start|>assistant\n"
)
inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.bfloat16)
# Forward
@@ -687,8 +697,8 @@ class InternVLLlamaIntegrationTest(unittest.TestCase):
)
# Prepare inputs
prompt = [
"<|im_start|>user\n<image>\nWrite a haiku for this image<|im_end|>\n<|im_start|>assistant\n",
"<|im_start|>user\n<image>\nDescribe this image<|im_end|>\n<|im_start|>assistant\n",
"<|im_start|>user\n<IMG_CONTEXT>\nWrite a haiku for this image<|im_end|>\n<|im_start|>assistant\n",
"<|im_start|>user\n<IMG_CONTEXT>\nDescribe this image<|im_end|>\n<|im_start|>assistant\n",
]
image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
image2 = Image.open(requests.get("https://www.ilankelman.org/stopsigns/australia.jpg", stream=True).raw)
@@ -724,8 +734,8 @@ class InternVLLlamaIntegrationTest(unittest.TestCase):
)
# Prepare inputs
prompt = [
"<|im_start|>user\n<image>\nWrite a haiku for this image<|im_end|>\n<|im_start|>assistant\n",
"<|im_start|>user\n<image><image>\nWhat are the difference between these two images?<|im_end|>\n<|im_start|>assistant\n",
"<|im_start|>user\n<IMG_CONTEXT>\nWrite a haiku for this image<|im_end|>\n<|im_start|>assistant\n",
"<|im_start|>user\n<IMG_CONTEXT><IMG_CONTEXT>\nWhat are the difference between these two images?<|im_end|>\n<|im_start|>assistant\n",
]
image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
image2 = Image.open(
@@ -795,6 +805,7 @@ class InternVLLlamaIntegrationTest(unittest.TestCase):
tokenize=True,
return_dict=True,
return_tensors="pt",
num_frames=8,
).to(torch_device, dtype=torch.float16)
output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
@@ -862,6 +873,7 @@ class InternVLLlamaIntegrationTest(unittest.TestCase):
return_dict=True,
return_tensors="pt",
padding=True,
num_frames=8,
).to(torch_device, dtype=torch.bfloat16)
output = model.generate(**inputs, do_sample=False, max_new_tokens=25)

View File

@@ -64,7 +64,8 @@ class InternVLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
**processor_kwargs,
)
processor.save_pretrained(cls.tmpdirname)
cls.image_token = processor.fake_image_token
cls.image_token = processor.image_token
cls.video_token = processor.video_token
@staticmethod
def prepare_processor_dict():
@@ -138,6 +139,7 @@ class InternVLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
return_dict=True,
return_tensors="pt",
padding=True,
num_frames=8,
)
# Process non batched inputs to check if the pixel_values and input_ids are reconstructed in the correct order when batched together
@@ -150,6 +152,7 @@ class InternVLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
return_dict=True,
return_tensors="pt",
padding=True,
num_frames=8,
)
# We slice with [-inputs["input_ids"].shape[1] :] as the input_ids are left padded
torch.testing.assert_close(
@@ -223,6 +226,7 @@ class InternVLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
tokenize=True,
return_dict=True,
return_tensors="np",
num_frames=8,
)
self.assertTrue(self.videos_input_name in out_dict_with_video)
@@ -272,30 +276,7 @@ class InternVLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
self.assertTrue(self.videos_input_name in out_dict_with_video)
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), num_frames)
# Load with `video_fps` arg
video_fps = 1
out_dict_with_video = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
video_fps=video_fps,
num_frames=None, # force to use default num_frames
return_tensors="np",
)
self.assertTrue(self.videos_input_name in out_dict_with_video)
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), video_fps * 10)
# Load with `video_fps` and `num_frames` args, should raise an error
with self.assertRaises(ValueError):
out_dict_with_video = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
video_fps=video_fps,
num_frames=num_frames,
)
# Load with `video_fps` arg is not possible with InternVL (skip)
# Load without any arg should use the default loading method
out_dict_with_video = processor.apply_chat_template(
@@ -305,8 +286,7 @@ class InternVLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
return_dict=True,
)
self.assertTrue(self.videos_input_name in out_dict_with_video)
# Difference with common tests, InternVLProcessor returns flattened video features, and uses 8 frames by default
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 8)
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 300)
# Load video as a list of frames (i.e. images). NOTE: each frame should have same size
# because we assume they come from one video