[internvl] fix chat template (#37656)
* fix chat template
* update
* update conversion
* rename `fake_image_token` in tests
This commit is contained in:
committed by
GitHub
parent
9ec8be56dd
commit
1e9087368c
@@ -257,6 +257,7 @@ InternVL models can also handle video inputs. Here is an example of how to perfo
|
||||
... add_generation_prompt=True,
|
||||
... tokenize=True,
|
||||
... return_dict=True,
|
||||
... num_frames=8,
|
||||
>>> ).to(model.device, dtype=torch.float16)
|
||||
|
||||
>>> output = model.generate(**inputs, max_new_tokens=25)
|
||||
|
||||
@@ -312,6 +312,7 @@ def write_tokenizer(save_dir: str, push_to_hub: bool = False, path: str = None,
|
||||
"start_image_token": "<img>",
|
||||
"end_image_token": "</img>",
|
||||
"context_image_token": "<IMG_CONTEXT>",
|
||||
"video_token": "<video>",
|
||||
},
|
||||
)
|
||||
tokenizer.model_max_length = CONTEXT_LENGTH
|
||||
|
||||
@@ -14,13 +14,11 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from functools import partial
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.processing_utils import (
|
||||
AllKwargsForChatTemplate,
|
||||
ImagesKwargs,
|
||||
ProcessingKwargs,
|
||||
ProcessorMixin,
|
||||
@@ -34,6 +32,7 @@ from ...image_utils import (
|
||||
VideoInput,
|
||||
VideoMetadata,
|
||||
concatenate_list,
|
||||
load_video,
|
||||
make_batched_videos,
|
||||
make_flat_list_of_images,
|
||||
)
|
||||
@@ -75,20 +74,12 @@ class InternVLProcessor(ProcessorMixin):
|
||||
image_seq_length = (config.image_size // config.patch_size) ** 2 * (config.scale_factor**2)
|
||||
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
|
||||
in a chat into a tokenizable string.
|
||||
fake_image_token (`str`, *optional*, defaults to `"<image>"`):
|
||||
The token to use for the image placeholder in the text. This token will be replaced by the
|
||||
appropriate image tokens when processing the text with images.
|
||||
fake_video_token (`str`, *optional*, defaults to `"<video>"`):
|
||||
The token to use for the video placeholder in the text. This token will be replaced by the
|
||||
appropriate image tokens when processing the text with videos.
|
||||
"""
|
||||
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
valid_kwargs = [
|
||||
"chat_template",
|
||||
"image_seq_length",
|
||||
"fake_image_token",
|
||||
"fake_video_token",
|
||||
]
|
||||
image_processor_class = "AutoImageProcessor"
|
||||
tokenizer_class = "AutoTokenizer"
|
||||
@@ -99,16 +90,14 @@ class InternVLProcessor(ProcessorMixin):
|
||||
tokenizer=None,
|
||||
image_seq_length: int = 256,
|
||||
chat_template=None,
|
||||
fake_image_token="<image>",
|
||||
fake_video_token="<video>",
|
||||
**kwargs,
|
||||
):
|
||||
self.image_seq_length = image_seq_length
|
||||
self.fake_image_token = fake_image_token
|
||||
self.fake_video_token = fake_video_token
|
||||
self.start_image_token = tokenizer.start_image_token
|
||||
self.end_image_token = tokenizer.end_image_token
|
||||
self.context_image_token = tokenizer.context_image_token
|
||||
self.image_token = tokenizer.context_image_token
|
||||
self.video_token = tokenizer.video_token
|
||||
self.image_token_id = tokenizer.context_image_token_id
|
||||
|
||||
super().__init__(image_processor, tokenizer, chat_template=chat_template, **kwargs)
|
||||
|
||||
@@ -131,24 +120,24 @@ class InternVLProcessor(ProcessorMixin):
|
||||
video_index = 0
|
||||
processed_text = []
|
||||
image_video_patches = []
|
||||
replace_strings = []
|
||||
# Support interleaved image and video in prompts:
|
||||
# Processed patches of images and videos are inserted in `image_video_patches` in the order they appear in the prompts
|
||||
for prompt in text:
|
||||
new_prompt = prompt
|
||||
while self.fake_image_token in new_prompt or self.fake_video_token in new_prompt:
|
||||
if self.fake_image_token in new_prompt and (
|
||||
self.fake_video_token not in new_prompt
|
||||
or new_prompt.index(self.fake_image_token) < new_prompt.index(self.fake_video_token)
|
||||
while self.image_token in new_prompt or self.video_token in new_prompt:
|
||||
if self.image_token in new_prompt and (
|
||||
self.video_token not in new_prompt
|
||||
or new_prompt.index(self.image_token) < new_prompt.index(self.video_token)
|
||||
):
|
||||
# Get the slice of patches corresponding to the current image
|
||||
start_index = image_num_patches_indices[image_index - 1] if image_index > 0 else 0
|
||||
end_index = image_num_patches_indices[image_index]
|
||||
image_video_patches.append(image_pixel_values[start_index:end_index])
|
||||
# Replace the corresponding image placeholder with the correct number of image tokens
|
||||
new_prompt = new_prompt.replace(
|
||||
self.fake_image_token,
|
||||
f"{self.start_image_token}{self.context_image_token * self.image_seq_length * image_num_patches[image_index]}{self.end_image_token}",
|
||||
1,
|
||||
new_prompt = new_prompt.replace(self.image_token, "<placeholder>", 1)
|
||||
replace_strings.append(
|
||||
f"{self.start_image_token}{self.image_token * self.image_seq_length * image_num_patches[image_index]}{self.end_image_token}"
|
||||
)
|
||||
image_index += 1
|
||||
else:
|
||||
@@ -163,11 +152,15 @@ class InternVLProcessor(ProcessorMixin):
|
||||
# Get the number of patches per frame and replace the video placeholder with the correct number of image tokens
|
||||
num_patches = list(video_num_patches[current_patch_index:end_patch_index])
|
||||
video_prompt = "\n".join(
|
||||
f"Frame{i + 1}: {self.start_image_token}{self.context_image_token * self.image_seq_length * num_patches[i]}{self.end_image_token}"
|
||||
f"Frame{i + 1}: {self.start_image_token}{self.image_token * self.image_seq_length * num_patches[i]}{self.end_image_token}"
|
||||
for i in range(len(num_patches))
|
||||
)
|
||||
new_prompt = new_prompt.replace(self.fake_video_token, video_prompt, 1)
|
||||
replace_strings.append(video_prompt)
|
||||
new_prompt = new_prompt.replace(self.video_token, "<placeholder>", 1)
|
||||
video_index += 1
|
||||
while "<placeholder>" in new_prompt:
|
||||
replace_str = replace_strings.pop(0)
|
||||
new_prompt = new_prompt.replace("<placeholder>", replace_str, 1)
|
||||
processed_text.append(new_prompt)
|
||||
|
||||
return processed_text, image_video_patches, image_index, video_index
|
||||
@@ -269,9 +262,11 @@ class InternVLProcessor(ProcessorMixin):
|
||||
# Concatenate the interleaved image and video patches (function agnostic to the patches type (list, numpy array, torch tensor))
|
||||
image_videos_inputs = {"pixel_values": concatenate_list(image_video_patches)}
|
||||
|
||||
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
||||
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
|
||||
self._check_special_mm_tokens(text, text_inputs, modalities=["image"])
|
||||
|
||||
return BatchFeature(data={**text_inputs, **image_videos_inputs})
|
||||
return BatchFeature(data={**text_inputs, **image_videos_inputs}, tensor_type=return_tensors)
|
||||
|
||||
def sample_indices_fn(
|
||||
self, metadata: VideoMetadata, num_frames: int = None, initial_shift: Union[bool, float, int] = True
|
||||
@@ -290,15 +285,13 @@ class InternVLProcessor(ProcessorMixin):
|
||||
Returns:
|
||||
`np.ndarray`: Array of frame indices to sample.
|
||||
"""
|
||||
num_frames = num_frames if num_frames is not None else metadata.total_num_frames
|
||||
|
||||
if initial_shift is True:
|
||||
initial_shift = metadata.total_num_frames / num_frames / 2
|
||||
if num_frames is not None:
|
||||
indices = np.arange(
|
||||
initial_shift, metadata.total_num_frames, metadata.total_num_frames / num_frames
|
||||
).astype(int)
|
||||
else:
|
||||
indices = np.arange(initial_shift, metadata.total_num_frames).astype(int)
|
||||
|
||||
indices = np.arange(initial_shift, metadata.total_num_frames, metadata.total_num_frames / num_frames).astype(
|
||||
int
|
||||
)
|
||||
return indices
|
||||
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
@@ -321,58 +314,39 @@ class InternVLProcessor(ProcessorMixin):
|
||||
image_processor_input_names = self.image_processor.model_input_names
|
||||
return list(tokenizer_input_names) + list(image_processor_input_names)
|
||||
|
||||
# Add model-specific video sampling method when applying the template
|
||||
def apply_chat_template(
|
||||
# TODO: raushan, has to be public method under `VideoProcessorBase` when API is added
|
||||
def _load_video_for_model(
|
||||
self,
|
||||
conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]]],
|
||||
chat_template: Optional[str] = None,
|
||||
num_frames: int = 8,
|
||||
initial_shift: Union[bool, float, int] = True,
|
||||
video_load_backend="pyav",
|
||||
**kwargs: Unpack[AllKwargsForChatTemplate],
|
||||
):
|
||||
video: Union[str, "VideoInput"],
|
||||
num_frames: Optional[int],
|
||||
backend: str = "pyav",
|
||||
initial_shift: bool = True,
|
||||
**kwargs,
|
||||
) -> np.array:
|
||||
"""
|
||||
Similar to the `apply_chat_template` method on tokenizers, this method applies a Jinja template to input
|
||||
conversations to turn them into a single tokenizable string.
|
||||
|
||||
The input is expected to be in the following format, where each message content is a list consisting of text and
|
||||
optionally image or video inputs. One can also provide an image, video, URL or local path which will be used to form
|
||||
`pixel_values` when `return_dict=True`. If not provided, one will get only the formatted text, optionally tokenized text.
|
||||
|
||||
conversation = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "image": "https://www.ilankelman.org/stopsigns/australia.jpg"},
|
||||
{"type": "text", "text": "Please describe this image in detail."},
|
||||
],
|
||||
},
|
||||
]
|
||||
Loads `video` to a numpy array.
|
||||
|
||||
Args:
|
||||
conversation (`Union[List[Dict[str, str]], List[List[Dict[str, str]]]]`):
|
||||
The conversation to format.
|
||||
chat_template (`Optional[str]`, *optional*):
|
||||
The Jinja template to use for formatting the conversation. If not provided, the tokenizer's
|
||||
chat template is used.
|
||||
num_frames (`int`, *optional*, defaults to 8):
|
||||
Number of frames to sample from a video when using the default `sample_indices_fn`.
|
||||
initial_shift (`bool`, `float` or `int`, defaults to `True`):
|
||||
The initial shift to apply when sampling frames using the default `sample_indices_fn`.
|
||||
If `True`, the shift is set so that frames are sampled from the middle of the video.
|
||||
"""
|
||||
sample_indices_fn = kwargs.pop(
|
||||
"sample_indices_fn", partial(self.sample_indices_fn, num_frames=num_frames, initial_shift=initial_shift)
|
||||
)
|
||||
video (`str` or `VideoInput`):
|
||||
The video to convert to the numpy array format. Can be a link to video or local path.
|
||||
num_frames (`int`, *optional*):
|
||||
Number of frames to sample uniformly. If not passed, the whole video is loaded.
|
||||
backend (`str`, *optional*, defaults to `"pyav"`):
|
||||
The backend to use when loading the video. Can be any of ["decord", "pyav", "opencv", "torchvision"]. Defaults to "pyav".
|
||||
initial_shift (`bool`, *optional*, defaults to `True`):
|
||||
The initial shift to apply when sampling frames. If `True`, the shift is set so that frames are sampled from the middle of the video.
|
||||
|
||||
return super().apply_chat_template(
|
||||
conversation,
|
||||
chat_template,
|
||||
video_load_backend=video_load_backend,
|
||||
num_frames=num_frames,
|
||||
sample_indices_fn=sample_indices_fn,
|
||||
**kwargs,
|
||||
)
|
||||
Returns:
|
||||
Tuple[`np.array`, Dict]: A tuple containing:
|
||||
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
|
||||
- Metadata dictionary.
|
||||
"""
|
||||
|
||||
def sample_indices_fn_func(metadata, **fn_kwargs):
|
||||
return self.sample_indices_fn(metadata, num_frames=num_frames, initial_shift=initial_shift, **fn_kwargs)
|
||||
|
||||
video, metadata = load_video(video, backend=backend, sample_indices_fn=sample_indices_fn_func)
|
||||
return video, metadata
|
||||
|
||||
|
||||
__all__ = ["InternVLProcessor"]
|
||||
|
||||
@@ -296,7 +296,9 @@ class InternVLQwen2IntegrationTest(unittest.TestCase):
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
prompt = "<|im_start|>user\n<image>\nPlease describe the image explicitly.<|im_end|>\n<|im_start|>assistant\n"
|
||||
prompt = (
|
||||
"<|im_start|>user\n<IMG_CONTEXT>\nPlease describe the image explicitly.<|im_end|>\n<|im_start|>assistant\n"
|
||||
)
|
||||
inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.bfloat16)
|
||||
with torch.no_grad():
|
||||
generate_ids = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||
@@ -314,7 +316,9 @@ class InternVLQwen2IntegrationTest(unittest.TestCase):
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
prompt = "<|im_start|>user\n<image>\nPlease describe the image explicitly.<|im_end|>\n<|im_start|>assistant\n"
|
||||
prompt = (
|
||||
"<|im_start|>user\n<IMG_CONTEXT>\nPlease describe the image explicitly.<|im_end|>\n<|im_start|>assistant\n"
|
||||
)
|
||||
inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.bfloat16)
|
||||
|
||||
# Forward
|
||||
@@ -378,8 +382,8 @@ class InternVLQwen2IntegrationTest(unittest.TestCase):
|
||||
)
|
||||
# Prepare inputs
|
||||
prompt = [
|
||||
"<|im_start|>user\n<image>\nWrite a haiku for this image<|im_end|>\n<|im_start|>assistant\n",
|
||||
"<|im_start|>user\n<image>\nDescribe this image<|im_end|>\n<|im_start|>assistant\n",
|
||||
"<|im_start|>user\n<IMG_CONTEXT>\nWrite a haiku for this image<|im_end|>\n<|im_start|>assistant\n",
|
||||
"<|im_start|>user\n<IMG_CONTEXT>\nDescribe this image<|im_end|>\n<|im_start|>assistant\n",
|
||||
]
|
||||
image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
|
||||
image2 = Image.open(requests.get("https://www.ilankelman.org/stopsigns/australia.jpg", stream=True).raw)
|
||||
@@ -414,8 +418,8 @@ class InternVLQwen2IntegrationTest(unittest.TestCase):
|
||||
)
|
||||
# Prepare inputs
|
||||
prompt = [
|
||||
"<|im_start|>user\n<image>\nWrite a haiku for this image<|im_end|>\n<|im_start|>assistant\n",
|
||||
"<|im_start|>user\n<image><image>\nWhat are the differences between these two images?<|im_end|>\n<|im_start|>assistant\n",
|
||||
"<|im_start|>user\n<IMG_CONTEXT>\nWrite a haiku for this image<|im_end|>\n<|im_start|>assistant\n",
|
||||
"<|im_start|>user\n<IMG_CONTEXT><IMG_CONTEXT>\nWhat are the differences between these two images?<|im_end|>\n<|im_start|>assistant\n",
|
||||
]
|
||||
image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
|
||||
image2 = Image.open(
|
||||
@@ -485,6 +489,7 @@ class InternVLQwen2IntegrationTest(unittest.TestCase):
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
num_frames=8,
|
||||
).to(torch_device, dtype=torch.float16)
|
||||
|
||||
output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
|
||||
@@ -552,6 +557,7 @@ class InternVLQwen2IntegrationTest(unittest.TestCase):
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
num_frames=8,
|
||||
).to(torch_device, dtype=torch.bfloat16)
|
||||
|
||||
output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
|
||||
@@ -601,7 +607,9 @@ class InternVLLlamaIntegrationTest(unittest.TestCase):
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
prompt = "<|im_start|>user\n<image>\nPlease describe the image explicitly.<|im_end|>\n<|im_start|>assistant\n"
|
||||
prompt = (
|
||||
"<|im_start|>user\n<IMG_CONTEXT>\nPlease describe the image explicitly.<|im_end|>\n<|im_start|>assistant\n"
|
||||
)
|
||||
inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.bfloat16)
|
||||
with torch.no_grad():
|
||||
generate_ids = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||
@@ -619,7 +627,9 @@ class InternVLLlamaIntegrationTest(unittest.TestCase):
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
prompt = "<|im_start|>user\n<image>\nPlease describe the image explicitly.<|im_end|>\n<|im_start|>assistant\n"
|
||||
prompt = (
|
||||
"<|im_start|>user\n<IMG_CONTEXT>\nPlease describe the image explicitly.<|im_end|>\n<|im_start|>assistant\n"
|
||||
)
|
||||
inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.bfloat16)
|
||||
|
||||
# Forward
|
||||
@@ -687,8 +697,8 @@ class InternVLLlamaIntegrationTest(unittest.TestCase):
|
||||
)
|
||||
# Prepare inputs
|
||||
prompt = [
|
||||
"<|im_start|>user\n<image>\nWrite a haiku for this image<|im_end|>\n<|im_start|>assistant\n",
|
||||
"<|im_start|>user\n<image>\nDescribe this image<|im_end|>\n<|im_start|>assistant\n",
|
||||
"<|im_start|>user\n<IMG_CONTEXT>\nWrite a haiku for this image<|im_end|>\n<|im_start|>assistant\n",
|
||||
"<|im_start|>user\n<IMG_CONTEXT>\nDescribe this image<|im_end|>\n<|im_start|>assistant\n",
|
||||
]
|
||||
image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
|
||||
image2 = Image.open(requests.get("https://www.ilankelman.org/stopsigns/australia.jpg", stream=True).raw)
|
||||
@@ -724,8 +734,8 @@ class InternVLLlamaIntegrationTest(unittest.TestCase):
|
||||
)
|
||||
# Prepare inputs
|
||||
prompt = [
|
||||
"<|im_start|>user\n<image>\nWrite a haiku for this image<|im_end|>\n<|im_start|>assistant\n",
|
||||
"<|im_start|>user\n<image><image>\nWhat are the difference between these two images?<|im_end|>\n<|im_start|>assistant\n",
|
||||
"<|im_start|>user\n<IMG_CONTEXT>\nWrite a haiku for this image<|im_end|>\n<|im_start|>assistant\n",
|
||||
"<|im_start|>user\n<IMG_CONTEXT><IMG_CONTEXT>\nWhat are the difference between these two images?<|im_end|>\n<|im_start|>assistant\n",
|
||||
]
|
||||
image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
|
||||
image2 = Image.open(
|
||||
@@ -795,6 +805,7 @@ class InternVLLlamaIntegrationTest(unittest.TestCase):
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
num_frames=8,
|
||||
).to(torch_device, dtype=torch.float16)
|
||||
|
||||
output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
|
||||
@@ -862,6 +873,7 @@ class InternVLLlamaIntegrationTest(unittest.TestCase):
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
num_frames=8,
|
||||
).to(torch_device, dtype=torch.bfloat16)
|
||||
|
||||
output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
|
||||
|
||||
@@ -64,7 +64,8 @@ class InternVLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
**processor_kwargs,
|
||||
)
|
||||
processor.save_pretrained(cls.tmpdirname)
|
||||
cls.image_token = processor.fake_image_token
|
||||
cls.image_token = processor.image_token
|
||||
cls.video_token = processor.video_token
|
||||
|
||||
@staticmethod
|
||||
def prepare_processor_dict():
|
||||
@@ -138,6 +139,7 @@ class InternVLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
num_frames=8,
|
||||
)
|
||||
|
||||
# Process non batched inputs to check if the pixel_values and input_ids are reconstructed in the correct order when batched together
|
||||
@@ -150,6 +152,7 @@ class InternVLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
num_frames=8,
|
||||
)
|
||||
# We slice with [-inputs["input_ids"].shape[1] :] as the input_ids are left padded
|
||||
torch.testing.assert_close(
|
||||
@@ -223,6 +226,7 @@ class InternVLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="np",
|
||||
num_frames=8,
|
||||
)
|
||||
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||
|
||||
@@ -272,30 +276,7 @@ class InternVLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), num_frames)
|
||||
|
||||
# Load with `video_fps` arg
|
||||
video_fps = 1
|
||||
out_dict_with_video = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
video_fps=video_fps,
|
||||
num_frames=None, # force to use default num_frames
|
||||
return_tensors="np",
|
||||
)
|
||||
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), video_fps * 10)
|
||||
|
||||
# Load with `video_fps` and `num_frames` args, should raise an error
|
||||
with self.assertRaises(ValueError):
|
||||
out_dict_with_video = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
video_fps=video_fps,
|
||||
num_frames=num_frames,
|
||||
)
|
||||
# Loading with the `video_fps` arg is not supported by InternVL (skipped)
|
||||
|
||||
# Load without any arg should use the default loading method
|
||||
out_dict_with_video = processor.apply_chat_template(
|
||||
@@ -305,8 +286,7 @@ class InternVLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
return_dict=True,
|
||||
)
|
||||
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||
# Difference with common tests, InternVLProcessor returns flattened video features, and uses 8 frames by default
|
||||
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 8)
|
||||
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 300)
|
||||
|
||||
# Load video as a list of frames (i.e. images). NOTE: each frame should have same size
|
||||
# because we assume they come from one video
|
||||
|
||||
Reference in New Issue
Block a user