[internvl] fix chat template (#37656)

* fix chat template

* update

* update conversion

* rename `fake_image_token` in tests
This commit is contained in:
Raushan Turganbay
2025-04-23 16:56:36 +02:00
committed by GitHub
parent 9ec8be56dd
commit 1e9087368c
5 changed files with 88 additions and 120 deletions

View File

@@ -257,6 +257,7 @@ InternVL models can also handle video inputs. Here is an example of how to perfo
... add_generation_prompt=True, ... add_generation_prompt=True,
... tokenize=True, ... tokenize=True,
... return_dict=True, ... return_dict=True,
... num_frames=8,
>>> ).to(model.device, dtype=torch.float16) >>> ).to(model.device, dtype=torch.float16)
>>> output = model.generate(**inputs, max_new_tokens=25) >>> output = model.generate(**inputs, max_new_tokens=25)

View File

@@ -312,6 +312,7 @@ def write_tokenizer(save_dir: str, push_to_hub: bool = False, path: str = None,
"start_image_token": "<img>", "start_image_token": "<img>",
"end_image_token": "</img>", "end_image_token": "</img>",
"context_image_token": "<IMG_CONTEXT>", "context_image_token": "<IMG_CONTEXT>",
"video_token": "<video>",
}, },
) )
tokenizer.model_max_length = CONTEXT_LENGTH tokenizer.model_max_length = CONTEXT_LENGTH

View File

@@ -14,13 +14,11 @@
# limitations under the License. # limitations under the License.
from functools import partial from typing import List, Optional, Union
from typing import Dict, List, Optional, Union
import numpy as np import numpy as np
from transformers.processing_utils import ( from transformers.processing_utils import (
AllKwargsForChatTemplate,
ImagesKwargs, ImagesKwargs,
ProcessingKwargs, ProcessingKwargs,
ProcessorMixin, ProcessorMixin,
@@ -34,6 +32,7 @@ from ...image_utils import (
VideoInput, VideoInput,
VideoMetadata, VideoMetadata,
concatenate_list, concatenate_list,
load_video,
make_batched_videos, make_batched_videos,
make_flat_list_of_images, make_flat_list_of_images,
) )
@@ -75,20 +74,12 @@ class InternVLProcessor(ProcessorMixin):
image_seq_length = (config.image_size // config.patch_size) ** 2 * (config.scale_factor**2) image_seq_length = (config.image_size // config.patch_size) ** 2 * (config.scale_factor**2)
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
in a chat into a tokenizable string. in a chat into a tokenizable string.
fake_image_token (`str`, *optional*, defaults to `"<image>"`):
The token to use for the image placeholder in the text. This token will be replaced by the
appropriate image tokens when processing the text with images.
fake_video_token (`str`, *optional*, defaults to `"<video>"`):
The token to use for the video placeholder in the text. This token will be replaced by the
appropriate image tokens when processing the text with videos.
""" """
attributes = ["image_processor", "tokenizer"] attributes = ["image_processor", "tokenizer"]
valid_kwargs = [ valid_kwargs = [
"chat_template", "chat_template",
"image_seq_length", "image_seq_length",
"fake_image_token",
"fake_video_token",
] ]
image_processor_class = "AutoImageProcessor" image_processor_class = "AutoImageProcessor"
tokenizer_class = "AutoTokenizer" tokenizer_class = "AutoTokenizer"
@@ -99,16 +90,14 @@ class InternVLProcessor(ProcessorMixin):
tokenizer=None, tokenizer=None,
image_seq_length: int = 256, image_seq_length: int = 256,
chat_template=None, chat_template=None,
fake_image_token="<image>",
fake_video_token="<video>",
**kwargs, **kwargs,
): ):
self.image_seq_length = image_seq_length self.image_seq_length = image_seq_length
self.fake_image_token = fake_image_token
self.fake_video_token = fake_video_token
self.start_image_token = tokenizer.start_image_token self.start_image_token = tokenizer.start_image_token
self.end_image_token = tokenizer.end_image_token self.end_image_token = tokenizer.end_image_token
self.context_image_token = tokenizer.context_image_token self.image_token = tokenizer.context_image_token
self.video_token = tokenizer.video_token
self.image_token_id = tokenizer.context_image_token_id
super().__init__(image_processor, tokenizer, chat_template=chat_template, **kwargs) super().__init__(image_processor, tokenizer, chat_template=chat_template, **kwargs)
@@ -131,24 +120,24 @@ class InternVLProcessor(ProcessorMixin):
video_index = 0 video_index = 0
processed_text = [] processed_text = []
image_video_patches = [] image_video_patches = []
replace_strings = []
# Support interleaved image and video in prompts: # Support interleaved image and video in prompts:
# Processed patches of images and videos are inserted in `image_video_patches` in the order they appear in the prompts # Processed patches of images and videos are inserted in `image_video_patches` in the order they appear in the prompts
for prompt in text: for prompt in text:
new_prompt = prompt new_prompt = prompt
while self.fake_image_token in new_prompt or self.fake_video_token in new_prompt: while self.image_token in new_prompt or self.video_token in new_prompt:
if self.fake_image_token in new_prompt and ( if self.image_token in new_prompt and (
self.fake_video_token not in new_prompt self.video_token not in new_prompt
or new_prompt.index(self.fake_image_token) < new_prompt.index(self.fake_video_token) or new_prompt.index(self.image_token) < new_prompt.index(self.video_token)
): ):
# Get the slice of patches corresponding to the current image # Get the slice of patches corresponding to the current image
start_index = image_num_patches_indices[image_index - 1] if image_index > 0 else 0 start_index = image_num_patches_indices[image_index - 1] if image_index > 0 else 0
end_index = image_num_patches_indices[image_index] end_index = image_num_patches_indices[image_index]
image_video_patches.append(image_pixel_values[start_index:end_index]) image_video_patches.append(image_pixel_values[start_index:end_index])
# Replace the corresponding image placeholder with the correct number of image tokens # Replace the corresponding image placeholder with the correct number of image tokens
new_prompt = new_prompt.replace( new_prompt = new_prompt.replace(self.image_token, "<placeholder>", 1)
self.fake_image_token, replace_strings.append(
f"{self.start_image_token}{self.context_image_token * self.image_seq_length * image_num_patches[image_index]}{self.end_image_token}", f"{self.start_image_token}{self.image_token * self.image_seq_length * image_num_patches[image_index]}{self.end_image_token}"
1,
) )
image_index += 1 image_index += 1
else: else:
@@ -163,11 +152,15 @@ class InternVLProcessor(ProcessorMixin):
# Get the number of patches per frame and replace the video placeholder with the correct number of image tokens # Get the number of patches per frame and replace the video placeholder with the correct number of image tokens
num_patches = list(video_num_patches[current_patch_index:end_patch_index]) num_patches = list(video_num_patches[current_patch_index:end_patch_index])
video_prompt = "\n".join( video_prompt = "\n".join(
f"Frame{i + 1}: {self.start_image_token}{self.context_image_token * self.image_seq_length * num_patches[i]}{self.end_image_token}" f"Frame{i + 1}: {self.start_image_token}{self.image_token * self.image_seq_length * num_patches[i]}{self.end_image_token}"
for i in range(len(num_patches)) for i in range(len(num_patches))
) )
new_prompt = new_prompt.replace(self.fake_video_token, video_prompt, 1) replace_strings.append(video_prompt)
new_prompt = new_prompt.replace(self.video_token, "<placeholder>", 1)
video_index += 1 video_index += 1
while "<placeholder>" in new_prompt:
replace_str = replace_strings.pop(0)
new_prompt = new_prompt.replace("<placeholder>", replace_str, 1)
processed_text.append(new_prompt) processed_text.append(new_prompt)
return processed_text, image_video_patches, image_index, video_index return processed_text, image_video_patches, image_index, video_index
@@ -269,9 +262,11 @@ class InternVLProcessor(ProcessorMixin):
# Concatenate the interleaved image and video patches (function agnostic to the patches type (list, numpy array, torch tensor)) # Concatenate the interleaved image and video patches (function agnostic to the patches type (list, numpy array, torch tensor))
image_videos_inputs = {"pixel_values": concatenate_list(image_video_patches)} image_videos_inputs = {"pixel_values": concatenate_list(image_video_patches)}
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"]) text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
self._check_special_mm_tokens(text, text_inputs, modalities=["image"])
return BatchFeature(data={**text_inputs, **image_videos_inputs}) return BatchFeature(data={**text_inputs, **image_videos_inputs}, tensor_type=return_tensors)
def sample_indices_fn( def sample_indices_fn(
self, metadata: VideoMetadata, num_frames: int = None, initial_shift: Union[bool, float, int] = True self, metadata: VideoMetadata, num_frames: int = None, initial_shift: Union[bool, float, int] = True
@@ -290,15 +285,13 @@ class InternVLProcessor(ProcessorMixin):
Returns: Returns:
`np.ndarray`: Array of frame indices to sample. `np.ndarray`: Array of frame indices to sample.
""" """
num_frames = num_frames if num_frames is not None else metadata.total_num_frames
if initial_shift is True: if initial_shift is True:
initial_shift = metadata.total_num_frames / num_frames / 2 initial_shift = metadata.total_num_frames / num_frames / 2
if num_frames is not None: indices = np.arange(initial_shift, metadata.total_num_frames, metadata.total_num_frames / num_frames).astype(
indices = np.arange( int
initial_shift, metadata.total_num_frames, metadata.total_num_frames / num_frames )
).astype(int)
else:
indices = np.arange(initial_shift, metadata.total_num_frames).astype(int)
return indices return indices
def batch_decode(self, *args, **kwargs): def batch_decode(self, *args, **kwargs):
@@ -321,58 +314,39 @@ class InternVLProcessor(ProcessorMixin):
image_processor_input_names = self.image_processor.model_input_names image_processor_input_names = self.image_processor.model_input_names
return list(tokenizer_input_names) + list(image_processor_input_names) return list(tokenizer_input_names) + list(image_processor_input_names)
# Add model-specific video sampling method when applying the template # TODO: raushan, has to be public method under `VideoProcessorBase` when API is added
def apply_chat_template( def _load_video_for_model(
self, self,
conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]]], video: Union[str, "VideoInput"],
chat_template: Optional[str] = None, num_frames: Optional[int],
num_frames: int = 8, backend: str = "pyav",
initial_shift: Union[bool, float, int] = True, initial_shift: bool = True,
video_load_backend="pyav", **kwargs,
**kwargs: Unpack[AllKwargsForChatTemplate], ) -> np.array:
):
""" """
Similar to the `apply_chat_template` method on tokenizers, this method applies a Jinja template to input Loads `video` to a numpy array.
conversations to turn them into a single tokenizable string.
The input is expected to be in the following format, where each message content is a list consisting of text and
optionally image or video inputs. One can also provide an image, video, URL or local path which will be used to form
`pixel_values` when `return_dict=True`. If not provided, one will get only the formatted text, optionally tokenized text.
conversation = [
{
"role": "user",
"content": [
{"type": "image", "image": "https://www.ilankelman.org/stopsigns/australia.jpg"},
{"type": "text", "text": "Please describe this image in detail."},
],
},
]
Args: Args:
conversation (`Union[List[Dict, [str, str]], List[List[Dict[str, str]]]]`): video (`str` or `VideoInput`):
The conversation to format. The video to convert to the numpy array format. Can be a link to video or local path.
chat_template (`Optional[str]`, *optional*): num_frames (`int`, *optional*):
The Jinja template to use for formatting the conversation. If not provided, the tokenizer's Number of frames to sample uniformly. If not passed, the whole video is loaded.
chat template is used. backend (`str`, *optional*, defaults to `"pyav"`):
num_frames (`int`, *optional*, defaults to 8): The backend to use when loading the video. Can be any of ["decord", "pyav", "opencv", "torchvision"]. Defaults to "pyav".
Number of frames to sample from a video when using the default `sample_indices_fn`. initial_shift (`bool`, *optional*, defaults to `True`):
initial_shift (`bool`, `float` or `int`, defaults to `0`): The initial shift to apply when sampling frames. If `True`, the shift is set so that frames are sampled from the middle of the video.
The initial shift to apply when sampling frames using the default `sample_indices_fn`.
If `True`, the shift is set so that frames are sampled from the middle of the video.
"""
sample_indices_fn = kwargs.pop(
"sample_indices_fn", partial(self.sample_indices_fn, num_frames=num_frames, initial_shift=initial_shift)
)
return super().apply_chat_template( Returns:
conversation, Tuple[`np.array`, Dict]: A tuple containing:
chat_template, - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
video_load_backend=video_load_backend, - Metadata dictionary.
num_frames=num_frames, """
sample_indices_fn=sample_indices_fn,
**kwargs, def sample_indices_fn_func(metadata, **fn_kwargs):
) return self.sample_indices_fn(metadata, num_frames=num_frames, initial_shift=initial_shift, **fn_kwargs)
video, metadata = load_video(video, backend=backend, sample_indices_fn=sample_indices_fn_func)
return video, metadata
__all__ = ["InternVLProcessor"] __all__ = ["InternVLProcessor"]

View File

@@ -296,7 +296,9 @@ class InternVLQwen2IntegrationTest(unittest.TestCase):
url = "http://images.cocodataset.org/val2017/000000039769.jpg" url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw) image = Image.open(requests.get(url, stream=True).raw)
prompt = "<|im_start|>user\n<image>\nPlease describe the image explicitly.<|im_end|>\n<|im_start|>assistant\n" prompt = (
"<|im_start|>user\n<IMG_CONTEXT>\nPlease describe the image explicitly.<|im_end|>\n<|im_start|>assistant\n"
)
inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.bfloat16) inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.bfloat16)
with torch.no_grad(): with torch.no_grad():
generate_ids = model.generate(**inputs, max_new_tokens=20, do_sample=False) generate_ids = model.generate(**inputs, max_new_tokens=20, do_sample=False)
@@ -314,7 +316,9 @@ class InternVLQwen2IntegrationTest(unittest.TestCase):
url = "http://images.cocodataset.org/val2017/000000039769.jpg" url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw) image = Image.open(requests.get(url, stream=True).raw)
prompt = "<|im_start|>user\n<image>\nPlease describe the image explicitly.<|im_end|>\n<|im_start|>assistant\n" prompt = (
"<|im_start|>user\n<IMG_CONTEXT>\nPlease describe the image explicitly.<|im_end|>\n<|im_start|>assistant\n"
)
inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.bfloat16) inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.bfloat16)
# Forward # Forward
@@ -378,8 +382,8 @@ class InternVLQwen2IntegrationTest(unittest.TestCase):
) )
# Prepare inputs # Prepare inputs
prompt = [ prompt = [
"<|im_start|>user\n<image>\nWrite a haiku for this image<|im_end|>\n<|im_start|>assistant\n", "<|im_start|>user\n<IMG_CONTEXT>\nWrite a haiku for this image<|im_end|>\n<|im_start|>assistant\n",
"<|im_start|>user\n<image>\nDescribe this image<|im_end|>\n<|im_start|>assistant\n", "<|im_start|>user\n<IMG_CONTEXT>\nDescribe this image<|im_end|>\n<|im_start|>assistant\n",
] ]
image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw) image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
image2 = Image.open(requests.get("https://www.ilankelman.org/stopsigns/australia.jpg", stream=True).raw) image2 = Image.open(requests.get("https://www.ilankelman.org/stopsigns/australia.jpg", stream=True).raw)
@@ -414,8 +418,8 @@ class InternVLQwen2IntegrationTest(unittest.TestCase):
) )
# Prepare inputs # Prepare inputs
prompt = [ prompt = [
"<|im_start|>user\n<image>\nWrite a haiku for this image<|im_end|>\n<|im_start|>assistant\n", "<|im_start|>user\n<IMG_CONTEXT>\nWrite a haiku for this image<|im_end|>\n<|im_start|>assistant\n",
"<|im_start|>user\n<image><image>\nWhat are the differences between these two images?<|im_end|>\n<|im_start|>assistant\n", "<|im_start|>user\n<IMG_CONTEXT><IMG_CONTEXT>\nWhat are the differences between these two images?<|im_end|>\n<|im_start|>assistant\n",
] ]
image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw) image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
image2 = Image.open( image2 = Image.open(
@@ -485,6 +489,7 @@ class InternVLQwen2IntegrationTest(unittest.TestCase):
tokenize=True, tokenize=True,
return_dict=True, return_dict=True,
return_tensors="pt", return_tensors="pt",
num_frames=8,
).to(torch_device, dtype=torch.float16) ).to(torch_device, dtype=torch.float16)
output = model.generate(**inputs, do_sample=False, max_new_tokens=25) output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
@@ -552,6 +557,7 @@ class InternVLQwen2IntegrationTest(unittest.TestCase):
return_dict=True, return_dict=True,
return_tensors="pt", return_tensors="pt",
padding=True, padding=True,
num_frames=8,
).to(torch_device, dtype=torch.bfloat16) ).to(torch_device, dtype=torch.bfloat16)
output = model.generate(**inputs, do_sample=False, max_new_tokens=25) output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
@@ -601,7 +607,9 @@ class InternVLLlamaIntegrationTest(unittest.TestCase):
url = "http://images.cocodataset.org/val2017/000000039769.jpg" url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw) image = Image.open(requests.get(url, stream=True).raw)
prompt = "<|im_start|>user\n<image>\nPlease describe the image explicitly.<|im_end|>\n<|im_start|>assistant\n" prompt = (
"<|im_start|>user\n<IMG_CONTEXT>\nPlease describe the image explicitly.<|im_end|>\n<|im_start|>assistant\n"
)
inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.bfloat16) inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.bfloat16)
with torch.no_grad(): with torch.no_grad():
generate_ids = model.generate(**inputs, max_new_tokens=20, do_sample=False) generate_ids = model.generate(**inputs, max_new_tokens=20, do_sample=False)
@@ -619,7 +627,9 @@ class InternVLLlamaIntegrationTest(unittest.TestCase):
url = "http://images.cocodataset.org/val2017/000000039769.jpg" url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw) image = Image.open(requests.get(url, stream=True).raw)
prompt = "<|im_start|>user\n<image>\nPlease describe the image explicitly.<|im_end|>\n<|im_start|>assistant\n" prompt = (
"<|im_start|>user\n<IMG_CONTEXT>\nPlease describe the image explicitly.<|im_end|>\n<|im_start|>assistant\n"
)
inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.bfloat16) inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.bfloat16)
# Forward # Forward
@@ -687,8 +697,8 @@ class InternVLLlamaIntegrationTest(unittest.TestCase):
) )
# Prepare inputs # Prepare inputs
prompt = [ prompt = [
"<|im_start|>user\n<image>\nWrite a haiku for this image<|im_end|>\n<|im_start|>assistant\n", "<|im_start|>user\n<IMG_CONTEXT>\nWrite a haiku for this image<|im_end|>\n<|im_start|>assistant\n",
"<|im_start|>user\n<image>\nDescribe this image<|im_end|>\n<|im_start|>assistant\n", "<|im_start|>user\n<IMG_CONTEXT>\nDescribe this image<|im_end|>\n<|im_start|>assistant\n",
] ]
image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw) image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
image2 = Image.open(requests.get("https://www.ilankelman.org/stopsigns/australia.jpg", stream=True).raw) image2 = Image.open(requests.get("https://www.ilankelman.org/stopsigns/australia.jpg", stream=True).raw)
@@ -724,8 +734,8 @@ class InternVLLlamaIntegrationTest(unittest.TestCase):
) )
# Prepare inputs # Prepare inputs
prompt = [ prompt = [
"<|im_start|>user\n<image>\nWrite a haiku for this image<|im_end|>\n<|im_start|>assistant\n", "<|im_start|>user\n<IMG_CONTEXT>\nWrite a haiku for this image<|im_end|>\n<|im_start|>assistant\n",
"<|im_start|>user\n<image><image>\nWhat are the difference between these two images?<|im_end|>\n<|im_start|>assistant\n", "<|im_start|>user\n<IMG_CONTEXT><IMG_CONTEXT>\nWhat are the difference between these two images?<|im_end|>\n<|im_start|>assistant\n",
] ]
image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw) image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
image2 = Image.open( image2 = Image.open(
@@ -795,6 +805,7 @@ class InternVLLlamaIntegrationTest(unittest.TestCase):
tokenize=True, tokenize=True,
return_dict=True, return_dict=True,
return_tensors="pt", return_tensors="pt",
num_frames=8,
).to(torch_device, dtype=torch.float16) ).to(torch_device, dtype=torch.float16)
output = model.generate(**inputs, do_sample=False, max_new_tokens=25) output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
@@ -862,6 +873,7 @@ class InternVLLlamaIntegrationTest(unittest.TestCase):
return_dict=True, return_dict=True,
return_tensors="pt", return_tensors="pt",
padding=True, padding=True,
num_frames=8,
).to(torch_device, dtype=torch.bfloat16) ).to(torch_device, dtype=torch.bfloat16)
output = model.generate(**inputs, do_sample=False, max_new_tokens=25) output = model.generate(**inputs, do_sample=False, max_new_tokens=25)

View File

@@ -64,7 +64,8 @@ class InternVLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
**processor_kwargs, **processor_kwargs,
) )
processor.save_pretrained(cls.tmpdirname) processor.save_pretrained(cls.tmpdirname)
cls.image_token = processor.fake_image_token cls.image_token = processor.image_token
cls.video_token = processor.video_token
@staticmethod @staticmethod
def prepare_processor_dict(): def prepare_processor_dict():
@@ -138,6 +139,7 @@ class InternVLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
return_dict=True, return_dict=True,
return_tensors="pt", return_tensors="pt",
padding=True, padding=True,
num_frames=8,
) )
# Process non batched inputs to check if the pixel_values and input_ids are reconstructed in the correct order when batched together # Process non batched inputs to check if the pixel_values and input_ids are reconstructed in the correct order when batched together
@@ -150,6 +152,7 @@ class InternVLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
return_dict=True, return_dict=True,
return_tensors="pt", return_tensors="pt",
padding=True, padding=True,
num_frames=8,
) )
# We slice with [-inputs["input_ids"].shape[1] :] as the input_ids are left padded # We slice with [-inputs["input_ids"].shape[1] :] as the input_ids are left padded
torch.testing.assert_close( torch.testing.assert_close(
@@ -223,6 +226,7 @@ class InternVLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
tokenize=True, tokenize=True,
return_dict=True, return_dict=True,
return_tensors="np", return_tensors="np",
num_frames=8,
) )
self.assertTrue(self.videos_input_name in out_dict_with_video) self.assertTrue(self.videos_input_name in out_dict_with_video)
@@ -272,30 +276,7 @@ class InternVLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
self.assertTrue(self.videos_input_name in out_dict_with_video) self.assertTrue(self.videos_input_name in out_dict_with_video)
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), num_frames) self.assertEqual(len(out_dict_with_video[self.videos_input_name]), num_frames)
# Load with `video_fps` arg # Load with `video_fps` arg is not possible with InternVL (skip)
video_fps = 1
out_dict_with_video = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
video_fps=video_fps,
num_frames=None, # force to use default num_frames
return_tensors="np",
)
self.assertTrue(self.videos_input_name in out_dict_with_video)
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), video_fps * 10)
# Load with `video_fps` and `num_frames` args, should raise an error
with self.assertRaises(ValueError):
out_dict_with_video = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
video_fps=video_fps,
num_frames=num_frames,
)
# Load without any arg should use the default loading method # Load without any arg should use the default loading method
out_dict_with_video = processor.apply_chat_template( out_dict_with_video = processor.apply_chat_template(
@@ -305,8 +286,7 @@ class InternVLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
return_dict=True, return_dict=True,
) )
self.assertTrue(self.videos_input_name in out_dict_with_video) self.assertTrue(self.videos_input_name in out_dict_with_video)
# Difference with common tests, InternVLProcessor returns flattened video features, and uses 8 frames by default self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 300)
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 8)
# Load video as a list of frames (i.e. images). NOTE: each frame should have same size # Load video as a list of frames (i.e. images). NOTE: each frame should have same size
# because we assume they come from one video # because we assume they come from one video