Prepare processors for VideoLLMs (#36149)

* allow processor to preprocess conversation + video metadata

* allow callable

* add test

* fix test

* nit: fix

* add metadata frames_indices

* Update src/transformers/processing_utils.py

Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>

* Update src/transformers/processing_utils.py

Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>

* port updates from Orr and add one more test

* Update src/transformers/processing_utils.py

Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>

* typo

* as dataclass

* style

* docstring + maek sure tests green

---------

Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>
This commit is contained in:
Raushan Turganbay
2025-02-14 11:34:08 +01:00
committed by GitHub
parent 33d1d715b0
commit 15ec971b8e
3 changed files with 428 additions and 138 deletions

View File

@@ -18,7 +18,7 @@ import os
from contextlib import redirect_stdout from contextlib import redirect_stdout
from dataclasses import dataclass from dataclasses import dataclass, field
from io import BytesIO from io import BytesIO
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Tuple, Union
import numpy as np import numpy as np
import requests import requests
@@ -126,6 +126,14 @@ class AnnotionFormat(ExplicitEnum):
COCO_PANOPTIC = AnnotationFormat.COCO_PANOPTIC.value COCO_PANOPTIC = AnnotationFormat.COCO_PANOPTIC.value
@dataclass
class VideoMetadata:
    """
    Metadata describing a video, as collected by the `read_video_*` decoders while loading it.

    Args:
        total_num_frames (`int`):
            Total number of frames in the source video, before any sampling is applied.
        fps (`float`):
            Frame rate of the source video as reported by the decoding backend.
        duration (`float`):
            Computed by the decoders as `total_num_frames / fps`, or `0` when the backend reports no fps.
        video_backend (`str`):
            Name of the backend that decoded the video, e.g. `"opencv"`, `"decord"`, `"pyav"` or `"torchvision"`.
        frames_indices (`np.ndarray` or `List[int]`, *optional*):
            Indices of the frames that were sampled from the video. The decoders fill this in after
            sampling; it stays `None` until then.
    """

    total_num_frames: int
    fps: float
    duration: float
    video_backend: str
    # Declared as a real field (the decoders used to attach it dynamically) so that it is always readable
    # and shows up in `repr`/`asdict`. It defaults to `None` so the four-argument constructor keeps working,
    # and is excluded from `==` so equality behaves as before and never has to compare numpy arrays.
    frames_indices: Optional[Union[np.ndarray, List[int]]] = field(default=None, compare=False)
AnnotationType = Dict[str, Union[int, str, List[Dict]]] AnnotationType = Dict[str, Union[int, str, List[Dict]]]
@@ -541,62 +549,83 @@ def load_image(image: Union[str, "PIL.Image.Image"], timeout: Optional[float] =
return image return image
def get_uniform_frame_indices(total_num_frames: int, num_frames: Optional[int] = None): def default_sample_indices_fn(metadata: VideoMetadata, num_frames=None, fps=None, **kwargs):
""" """
Creates a numpy array for uniform sampling of `num_frame` frames from `total_num_frames` A default sampling function that replicates the logic used in get_uniform_frame_indices,
when loading a video. while optionally handling `fps` if `num_frames` is not provided.
Args: Args:
total_num_frames (`int`): metadata (`VideoMetadata`):
Total number of frames that a video has. `VideoMetadata` object containing metadata about the video, such as "total_num_frames" or "fps".
num_frames (`int`, *optional*): num_frames (`int`, *optional*):
Number of frames to sample uniformly. If not specified, all frames are sampled. Number of frames to sample uniformly.
Returns:
np.ndarray: np array of frame indices that will be sampled.
"""
if num_frames is not None:
indices = np.arange(0, total_num_frames, total_num_frames / num_frames).astype(int)
else:
indices = np.arange(0, total_num_frames).astype(int)
return indices
def read_video_opencv(video_path: str, num_frames: Optional[int] = None, fps: Optional[int] = None):
"""
Decode the video with open-cv decoder.
Args:
video_path (`str`):
Path to the video file.
num_frames (`int`, *optional*):
Number of frames to sample uniformly. Should be passed only when `fps=None`.
If not specified and `fps==None`, all frames are sampled.
fps (`int`, *optional*): fps (`int`, *optional*):
Number of frames to sample per second. Should be passed only when `num_frames=None`. Desired frames per second. Takes priority over num_frames if both are provided.
If not specified and `num_frames==None`, all frames are sampled.
Returns: Returns:
np.ndarray: np array of decoded frames of shape (num_frames, height, width, 3). `np.ndarray`: Array of frame indices to sample.
""" """
video = cv2.VideoCapture(video_path) total_num_frames = metadata.total_num_frames
total_num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) video_fps = metadata.fps
video_fps = video.get(cv2.CAP_PROP_FPS)
# If num_frames is not given but fps is, calculate num_frames from fps
if num_frames is None and fps is not None: if num_frames is None and fps is not None:
num_frames = int(total_num_frames / video_fps * fps) num_frames = int(total_num_frames / video_fps * fps)
if num_frames > total_num_frames: if num_frames > total_num_frames:
raise ValueError( raise ValueError(
f"When loading the video with fps={fps}, we identified that num_frames ({num_frames}) > total_frames ({total_num_frames}) ." f"When loading the video with fps={fps}, we computed num_frames={num_frames} "
f"Make sure that fps of a video is less than the requested fps for loading. Detected video_fps={video_fps}" f"which exceeds total_num_frames={total_num_frames}. Check fps or video metadata."
) )
indices = get_uniform_frame_indices(total_num_frames, num_frames=num_frames)
if num_frames is not None:
indices = np.arange(0, total_num_frames, total_num_frames / num_frames, dtype=int)
else:
indices = np.arange(0, total_num_frames, dtype=int)
return indices
def read_video_opencv(
video_path: str,
sample_indices_fn: Callable,
**kwargs,
):
"""
Decode a video using the OpenCV backend.
Args:
video_path (`str`):
Path to the video file.
sample_indices_fn (`Callable`):
A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
If not provided, simple uniform sampling with fps is performed.
Example:
def sample_indices_fn(metadata, **kwargs):
return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
Returns:
Tuple[`np.array`, `VideoMetadata`]: A tuple containing:
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
- `VideoMetadata` object.
"""
video = cv2.VideoCapture(video_path)
total_num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
video_fps = video.get(cv2.CAP_PROP_FPS)
duration = total_num_frames / video_fps if video_fps else 0
metadata = VideoMetadata(
total_num_frames=int(total_num_frames), fps=float(video_fps), duration=float(duration), video_backend="opencv"
)
indices = sample_indices_fn(metadata=metadata, **kwargs)
index = 0 index = 0
frames = [] frames = []
while video.isOpened(): while video.isOpened():
success, frame = video.read() success, frame = video.read()
if not success:
break
if index in indices: if index in indices:
height, width, channel = frame.shape height, width, channel = frame.shape
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frames.append(frame[0:height, 0:width, 0:channel]) frames.append(frame[0:height, 0:width, 0:channel])
if success: if success:
index += 1 index += 1
@@ -604,70 +633,81 @@ def read_video_opencv(video_path: str, num_frames: Optional[int] = None, fps: Op
break break
video.release() video.release()
return np.stack(frames) metadata.frames_indices = indices
return np.stack(frames), metadata
def read_video_decord(video_path: str, num_frames: Optional[int] = None, fps: Optional[int] = None): def read_video_decord(
video_path: str,
sample_indices_fn: Optional[Callable] = None,
**kwargs,
):
""" """
Decode the video with Decord decoder. Decode a video using the Decord backend.
Args: Args:
video_path (`str`): video_path (`str`):
Path to the video file. Path to the video file.
num_frames (`int`, *optional*): sample_indices_fn (`Callable`, *optional*):
Number of frames to sample uniformly. Should be passed only when `fps=None`. A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
If not specified and `fps==None`, all frames are sampled. a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
fps (`int`, *optional*): If not provided, simple uniform sampling with fps is performed.
Number of frames to sample per second. Should be passed only when `num_frames=None`. Example:
If not specified and `num_frames==None`, all frames are sampled. def sample_indices_fn(metadata, **kwargs):
return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
Returns: Returns:
np.ndarray: np array of decoded frames of shape (num_frames, height, width, 3). Tuple[`np.array`, `VideoMetadata`]: A tuple containing:
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
- `VideoMetadata` object.
""" """
vr = VideoReader(uri=video_path, ctx=cpu(0)) # decord has problems with gpu vr = VideoReader(uri=video_path, ctx=cpu(0)) # decord has problems with gpu
video_fps = vr.get_avg_fps() video_fps = vr.get_avg_fps()
total_num_frames = len(vr) total_num_frames = len(vr)
if num_frames is None and fps is not None: duration = total_num_frames / video_fps if video_fps else 0
num_frames = int(total_num_frames / video_fps * fps) metadata = VideoMetadata(
if num_frames > total_num_frames: total_num_frames=int(total_num_frames), fps=float(video_fps), duration=float(duration), video_backend="decord"
raise ValueError( )
f"When loading the video with fps={fps}, we identified that num_frames ({num_frames}) > total_frames ({total_num_frames}) ."
f"Make sure that fps of a video is less than the requested fps for loading. Detected video_fps={video_fps}" indices = sample_indices_fn(metadata=metadata, **kwargs)
)
indices = get_uniform_frame_indices(total_num_frames=total_num_frames, num_frames=num_frames)
frames = vr.get_batch(indices).asnumpy() frames = vr.get_batch(indices).asnumpy()
return frames metadata.frames_indices = indices
return frames, metadata
def read_video_pyav(video_path: str, num_frames: Optional[int] = None, fps: Optional[int] = None): def read_video_pyav(
video_path: str,
sample_indices_fn: Callable,
**kwargs,
):
""" """
Decode the video with PyAV decoder. Decode the video with PyAV decoder.
Args: Args:
video_path (`str`): video_path (`str`):
Path to the video file. Path to the video file.
num_frames (`int`, *optional*): sample_indices_fn (`Callable`, *optional*):
Number of frames to sample uniformly. Should be passed only when `fps=None`. A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
If not specified and `fps==None`, all frames are sampled. a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
fps (`int`, *optional*): If not provided, simple uniform sampling with fps is performed.
Number of frames to sample per second. Should be passed only when `num_frames=None`. Example:
If not specified and `num_frames==None`, all frames are sampled. def sample_indices_fn(metadata, **kwargs):
return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
Returns: Returns:
np.ndarray: np array of decoded frames of shape (num_frames, height, width, 3). Tuple[`np.array`, `VideoMetadata`]: A tuple containing:
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
- `VideoMetadata` object.
""" """
container = av.open(video_path) container = av.open(video_path)
total_num_frames = container.streams.video[0].frames total_num_frames = container.streams.video[0].frames
video_fps = container.streams.video[0].average_rate # should we better use `av_guess_frame_rate`? video_fps = container.streams.video[0].average_rate # should we better use `av_guess_frame_rate`?
if num_frames is None and fps is not None: duration = total_num_frames / video_fps if video_fps else 0
num_frames = int(total_num_frames / video_fps * fps) metadata = VideoMetadata(
if num_frames > total_num_frames: total_num_frames=int(total_num_frames), fps=float(video_fps), duration=float(duration), video_backend="pyav"
raise ValueError( )
f"When loading the video with fps={fps}, we identified that num_frames ({num_frames}) > total_frames ({total_num_frames}) ." indices = sample_indices_fn(metadata=metadata, **kwargs)
f"Make sure that fps of a video is less than the requested fps for loading. Detected video_fps={video_fps}"
)
indices = get_uniform_frame_indices(total_num_frames, num_frames=num_frames)
frames = [] frames = []
container.seek(0) container.seek(0)
@@ -677,48 +717,58 @@ def read_video_pyav(video_path: str, num_frames: Optional[int] = None, fps: Opti
break break
if i >= 0 and i in indices: if i >= 0 and i in indices:
frames.append(frame) frames.append(frame)
return np.stack([x.to_ndarray(format="rgb24") for x in frames])
video = np.stack([x.to_ndarray(format="rgb24") for x in frames])
metadata.frames_indices = indices
return video, metadata
def read_video_torchvision(video_path: str, num_frames: Optional[int] = None, fps: Optional[int] = None): def read_video_torchvision(
video_path: str,
sample_indices_fn: Callable,
**kwargs,
):
""" """
Decode the video with torchvision decoder. Decode the video with torchvision decoder.
Args: Args:
video_path (`str`): video_path (`str`):
Path to the video file. Path to the video file.
num_frames (`int`, *optional*): sample_indices_fn (`Callable`, *optional*):
Number of frames to sample uniformly. Should be passed only when `fps=None`. A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
If not specified and `fps==None`, all frames are sampled. a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
fps (`int`, *optional*): If not provided, simple uniform sampling with fps is performed.
Number of frames to sample per second. Should be passed only when `num_frames=None`. Example:
If not specified and `num_frames==None`, all frames are sampled. def sample_indices_fn(metadata, **kwargs):
return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
Returns: Returns:
np.ndarray: np array of decoded frames of shape (num_frames, height, width, 3). Tuple[`np.array`, `VideoMetadata`]: A tuple containing:
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
- `VideoMetadata` object.
""" """
video, _, info = torchvision_io.read_video( video, _, info = torchvision_io.read_video(
video_path, video_path,
start_pts=0.0, start_pts=0.0,
end_pts=None, end_pts=None,
pts_unit="sec", pts_unit="sec",
output_format="TCHW", output_format="THWC",
) )
video_fps = info["video_fps"] video_fps = info["video_fps"]
total_num_frames = video.size(0) - 1 total_num_frames = video.size(0)
if num_frames is None and fps is not None: duration = total_num_frames / video_fps if video_fps else 0
num_frames = int(total_num_frames / video_fps * fps) metadata = VideoMetadata(
if num_frames > total_num_frames: total_num_frames=int(total_num_frames),
raise ValueError( fps=float(video_fps),
f"When loading the video with fps={fps}, we identified that num_frames ({num_frames}) > total_frames ({total_num_frames}) ." duration=float(duration),
f"Make sure that fps of a video is less than the requested fps for loading. Detected video_fps={video_fps}" video_backend="torchvision",
) )
if num_frames is not None: indices = sample_indices_fn(metadata=metadata, **kwargs)
idx = torch.linspace(0, video.size(0) - 1, num_frames, dtype=torch.int64)
return video[idx]
return video video = video[indices].contiguous().numpy()
metadata.frames_indices = indices
return video, metadata
VIDEO_DECODERS = { VIDEO_DECODERS = {
@@ -734,6 +784,8 @@ def load_video(
num_frames: Optional[int] = None, num_frames: Optional[int] = None,
fps: Optional[int] = None, fps: Optional[int] = None,
backend: str = "opencv", backend: str = "opencv",
sample_indices_fn: Optional[Callable] = None,
**kwargs,
) -> np.array: ) -> np.array:
""" """
Loads `video` to a numpy array. Loads `video` to a numpy array.
@@ -748,13 +800,36 @@ def load_video(
If not specified and `num_frames==None`, all frames are sampled. If not specified and `num_frames==None`, all frames are sampled.
backend (`str`, *optional*, defaults to `"opencv"`): backend (`str`, *optional*, defaults to `"opencv"`):
The backend to use when loading the video. Can be any of ["decord", "pyav", "opencv", "torchvision"]. Defaults to "opencv". The backend to use when loading the video. Can be any of ["decord", "pyav", "opencv", "torchvision"]. Defaults to "opencv".
sample_indices_fn (`Callable`, *optional*):
A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
If not provided, simple uniform sampling with fps is performed, otherwise `sample_indices_fn` has priority over other args.
The function receives the video `metadata` along with any extra kwargs passed to `load_video` and should output valid
indices at which the video should be sampled. For example:
Example:
def sample_indices_fn(metadata, **kwargs):
return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
Returns: Returns:
`np.array`: A numpy array of shape (num_frames, channels, height, width). Tuple[`np.array`, Dict]: A tuple containing:
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
- Metadata dictionary.
""" """
if fps is not None and num_frames is not None: # If `sample_indices_fn` is given, we can accept any args as those might be needed by custom `sample_indices_fn`
raise ValueError("`num_frames` and `fps` are mutually exclusive arguments, please use only one!") if fps is not None and num_frames is not None and sample_indices_fn is None:
raise ValueError(
"`num_frames`, `fps`, and `sample_indices_fn` are mutually exclusive arguments, please use only one!"
)
# If user didn't pass a sampling function, create one on the fly with default logic
if sample_indices_fn is None:
def sample_indices_fn_func(metadata, **fn_kwargs):
return default_sample_indices_fn(metadata, num_frames=num_frames, fps=fps, **fn_kwargs)
sample_indices_fn = sample_indices_fn_func
if video.startswith("https://www.youtube.com") or video.startswith("http://www.youtube.com"): if video.startswith("https://www.youtube.com") or video.startswith("http://www.youtube.com"):
if not is_yt_dlp_available(): if not is_yt_dlp_available():
@@ -796,8 +871,8 @@ def load_video(
) )
video_decoder = VIDEO_DECODERS[backend] video_decoder = VIDEO_DECODERS[backend]
video = video_decoder(file_obj, num_frames=num_frames, fps=fps) video, metadata = video_decoder(file_obj, sample_indices_fn, **kwargs)
return video return video, metadata
def load_images( def load_images(

View File

@@ -24,13 +24,21 @@ import sys
import typing import typing
import warnings import warnings
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union from typing import Any, Callable, Dict, List, Optional, Tuple, TypedDict, Union
import numpy as np import numpy as np
import typing_extensions import typing_extensions
from .dynamic_module_utils import custom_object_save from .dynamic_module_utils import custom_object_save
from .image_utils import ChannelDimension, is_valid_image, is_vision_available, load_image, load_video from .image_utils import (
ChannelDimension,
ImageInput,
VideoInput,
is_valid_image,
is_vision_available,
load_image,
load_video,
)
if is_vision_available(): if is_vision_available():
@@ -339,14 +347,10 @@ class ProcessingKwargs(TextKwargs, ImagesKwargs, VideosKwargs, AudioKwargs, Comm
} }
class ChatTemplateKwargs(TypedDict, total=False): class TokenizerChatTemplateKwargs(TypedDict, total=False):
""" """
Keyword arguments for processor chat templates. Keyword arguments for tokenizer's `apply_chat_template`, when it is called from within a processor.
tokenize (`bool`, *optional*, defaults to `False`):
Whether to tokenize the output or not.
return_dict (`bool`, defaults to `False`):
Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`.
tools (`List[Dict]`, *optional*): tools (`List[Dict]`, *optional*):
A list of tools (callable functions) that will be accessible to the model. If the template does not A list of tools (callable functions) that will be accessible to the model. If the template does not
support function calling, this argument will have no effect. Each tool should be passed as a JSON Schema, support function calling, this argument will have no effect. Each tool should be passed as a JSON Schema,
@@ -373,6 +377,23 @@ class ChatTemplateKwargs(TypedDict, total=False):
Whether to return a mask of the assistant generated tokens. For tokens generated by the assistant, Whether to return a mask of the assistant generated tokens. For tokens generated by the assistant,
the mask will contain 1. For user and system tokens, the mask will contain 0. the mask will contain 1. For user and system tokens, the mask will contain 0.
This functionality is only available for chat templates that support it via the `{% generation %}` keyword. This functionality is only available for chat templates that support it via the `{% generation %}` keyword.
"""
tools: Optional[List[Dict]] = None
documents: Optional[List[Dict[str, str]]] = None
add_generation_prompt: Optional[bool] = False
continue_final_message: Optional[bool] = False
return_assistant_tokens_mask: Optional[bool] = False
class ProcessorChatTemplateKwargs(TokenizerChatTemplateKwargs, total=False):
"""
Keyword arguments for processor chat templates.
tokenize (`bool`, *optional*, defaults to `False`):
Whether to tokenize the output or not.
return_dict (`bool`, defaults to `False`):
Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`.
num_frames (`int`, *optional*): num_frames (`int`, *optional*):
Number of frames to sample uniformly. If not passed, the whole video is loaded. Number of frames to sample uniformly. If not passed, the whole video is loaded.
video_load_backend (`str`, *optional*, defaults to `"pyav"`): video_load_backend (`str`, *optional*, defaults to `"pyav"`):
@@ -382,22 +403,28 @@ class ChatTemplateKwargs(TypedDict, total=False):
video_fps (`int`, *optional*): video_fps (`int`, *optional*):
Number of frames to sample per second. Should be passed only when `num_frames=None`. Number of frames to sample per second. Should be passed only when `num_frames=None`.
If not specified and `num_frames==None`, all frames are sampled. If not specified and `num_frames==None`, all frames are sampled.
sample_indices_fn (`Callable`, *optional*):
A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
If not provided, simple uniform sampling with fps is performed, otherwise `sample_indices_fn` has priority over other args.
The function receives the video `metadata` as a keyword argument and should output valid
indices at which the video should be sampled. For example:
def sample_indices_fn(metadata, **kwargs):
# add your sampling logic here ...
return np.linspace(start_idx, end_idx, num_frames, dtype=int)
""" """
tokenize: Optional[bool] = False tokenize: Optional[bool] = False
return_dict: Optional[bool] = False return_dict: Optional[bool] = False
tools: Optional[List[Dict]] = None
documents: Optional[List[Dict[str, str]]] = None
add_generation_prompt: Optional[bool] = False
continue_final_message: Optional[bool] = False
return_assistant_tokens_mask: Optional[bool] = False
num_frames: Optional[int] = None num_frames: Optional[int] = None
video_load_backend: Optional[str] = "pyav" video_load_backend: Optional[str] = "pyav"
video_fps: Optional[int] = None video_fps: Optional[int] = None
sample_indices_fn: Optional[Callable] = None
class AllKwargsForChatTemplate( class AllKwargsForChatTemplate(
TextKwargs, ImagesKwargs, VideosKwargs, AudioKwargs, CommonKwargs, ChatTemplateKwargs TextKwargs, ImagesKwargs, VideosKwargs, AudioKwargs, CommonKwargs, ProcessorChatTemplateKwargs
): ... ): ...
@@ -1165,9 +1192,43 @@ class ProcessorMixin(PushToHubMixin):
) )
return {arg_name: arg_value for arg_value, arg_name in zip(args, self.optional_call_args)} return {arg_name: arg_value for arg_value, arg_name in zip(args, self.optional_call_args)}
def _process_messages_for_chat_template(
    self,
    conversation: List[List[Dict[str, str]]],
    batch_images: List[ImageInput],
    batch_videos: List[VideoInput],
    batch_video_metadata: List[List[Dict[str, Any]]],
    **chat_template_kwargs: Unpack[AllKwargsForChatTemplate],
):
    """
    Hook used within `apply_chat_template` when a model has a special way to process conversation history.
    For example, video models might want to specify in the prompt the duration of the video or which frame
    indices at which timestamps were sampled. This information cannot be accessed before the video is loaded,
    so the hook runs after the visuals are loaded and before the conversation is rendered with the Jinja template.

    For most models it is a no-op, and it must be overridden by model processors which require special processing.

    Args:
        conversation (`List[List[Dict[str, str]]]`):
            The conversation to process. Always comes in batched format.
        batch_images (`List[List[ImageInput]]`):
            Batch of images that were loaded from url/path defined in the conversation. The images
            are ordered in the same way as in the conversation. Comes in nested list format, one list of
            `PIL` images per batch.
        batch_videos (`List[List[ImageInput]]`):
            Batch of videos that were loaded from url/path defined in the conversation. The videos
            are ordered in the same way as in the conversation. Comes in nested list format, one list of
            4D video arrays per batch.
        batch_video_metadata (`List[List[Dict[str, Any]]]`):
            Batch of metadata returned from loading videos. That includes video fps, duration and total number
            of frames in the original video. Metadata are ordered in the same way as `batch_videos`. Comes in
            nested list format, one list of metadata entries per batch. An entry is `None` when the video was
            assembled from a list of images, because no metadata can be inferred in that case.
        **chat_template_kwargs:
            The chat-template keyword arguments forwarded by `apply_chat_template`. Unused by this default
            implementation.

    Returns:
        The conversation, returned unchanged by this default implementation.
    """
    return conversation
def apply_chat_template( def apply_chat_template(
self, self,
conversation: Union[List[Dict[str, str]]], conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]]],
chat_template: Optional[str] = None, chat_template: Optional[str] = None,
**kwargs: Unpack[AllKwargsForChatTemplate], **kwargs: Unpack[AllKwargsForChatTemplate],
) -> str: ) -> str:
@@ -1190,7 +1251,7 @@ class ProcessorMixin(PushToHubMixin):
] ]
Args: Args:
conversation (`List[Dict, str, str]`): conversation (`Union[List[Dict, [str, str]], List[List[Dict[str, str]]]]`):
The conversation to format. The conversation to format.
chat_template (`Optional[str]`, *optional*): chat_template (`Optional[str]`, *optional*):
The Jinja template to use for formatting the conversation. If not provided, the tokenizer's The Jinja template to use for formatting the conversation. If not provided, the tokenizer's
@@ -1207,39 +1268,42 @@ class ProcessorMixin(PushToHubMixin):
"https://huggingface.co/docs/transformers/main/en/chat_templating for more information." "https://huggingface.co/docs/transformers/main/en/chat_templating for more information."
) )
# Fill two sets of kwargs that should be used by tokenizer's `apply_chat_template`
# and for multimodal chat template
tokenizer_template_kwargs = {}
for tokenizer_key in TokenizerChatTemplateKwargs.__annotations__.keys():
tokenizer_value = getattr(TokenizerChatTemplateKwargs, tokenizer_key, None)
value = kwargs.pop(tokenizer_key, tokenizer_value)
tokenizer_template_kwargs[tokenizer_key] = value
chat_template_kwargs = {} chat_template_kwargs = {}
for key in ChatTemplateKwargs.__annotations__.keys(): for key in ProcessorChatTemplateKwargs.__annotations__.keys():
value = kwargs.pop(key, getattr(ChatTemplateKwargs, key)) processor_value = getattr(ProcessorChatTemplateKwargs, key, None)
value = kwargs.pop(key, processor_value)
chat_template_kwargs[key] = value chat_template_kwargs[key] = value
# Pop kwargs that should not be used by tokenizer's `apply_chat_template`
tokenize = chat_template_kwargs.pop("tokenize")
return_dict = chat_template_kwargs.pop("return_dict")
num_frames = chat_template_kwargs.pop("num_frames")
video_fps = chat_template_kwargs.pop("video_fps")
video_load_backend = chat_template_kwargs.pop("video_load_backend")
prompt = self.tokenizer.apply_chat_template(
conversation,
chat_template=chat_template,
tokenize=False,
return_dict=False,
**chat_template_kwargs,
)
if isinstance(conversation, (list, tuple)) and ( if isinstance(conversation, (list, tuple)) and (
isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "content") isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "content")
): ):
conversations = conversation
is_batched = True is_batched = True
conversations = conversation
else: else:
conversations = [conversation]
is_batched = False is_batched = False
conversations = [conversation]
num_frames = chat_template_kwargs.get("num_frames")
video_fps = chat_template_kwargs.get("video_fps")
video_load_backend = chat_template_kwargs.get("video_load_backend")
tokenize = chat_template_kwargs.get("tokenize")
return_dict = chat_template_kwargs.get("return_dict")
sample_indices_fn = chat_template_kwargs.get("sample_indices_fn")
if tokenize: if tokenize:
batch_images, batch_videos = [], [] batch_images, batch_videos = [], []
batch_video_metadata = []
for conversation in conversations: for conversation in conversations:
images, videos = [], [] images, videos = [], []
video_metadata = []
for message in conversation: for message in conversation:
visuals = [content for content in message["content"] if content["type"] in ["image", "video"]] visuals = [content for content in message["content"] if content["type"] in ["image", "video"]]
image_fnames = [ image_fnames = [
@@ -1261,17 +1325,51 @@ class ProcessorMixin(PushToHubMixin):
video = [np.array(load_image(image_fname)).T for image_fname in fname] video = [np.array(load_image(image_fname)).T for image_fname in fname]
# create a 4D video because `load_video` always returns a 4D array # create a 4D video because `load_video` always returns a 4D array
video = np.stack(video) video = np.stack(video)
metadata = None
logger.warning(
"When loading the video from list of images, we cannot infer metadata such as `fps` or `duration`. "
"If you model applies special processing based on metadata, please load the whole video and let the model sample frames."
)
else: else:
video = load_video(fname, num_frames=num_frames, fps=video_fps, backend=video_load_backend) video, metadata = load_video(
fname,
num_frames=num_frames,
fps=video_fps,
backend=video_load_backend,
sample_indices_fn=sample_indices_fn,
)
videos.append(video) videos.append(video)
video_metadata.append(metadata)
# Currently all processors can accept accept nested list of batches, but not flat list of visuals # Currently all processors can accept nested list of batches, but not flat list of visuals
# So we'll make a batched list of images and let the processor handle it # So we'll make a batched list of images and let the processor handle it
if images: if images:
batch_images.append(images) batch_images.append(images)
if videos: if videos:
batch_videos.append(videos) batch_videos.append(videos)
batch_video_metadata.append(video_metadata)
# Process conversation with video/image information if needed. Then convert into a prompt using Jinja template
conversations = self._process_messages_for_chat_template(
conversations,
batch_images=batch_images,
batch_videos=batch_videos,
batch_video_metadata=batch_video_metadata,
**chat_template_kwargs,
)
prompt = self.tokenizer.apply_chat_template(
conversations,
chat_template=chat_template,
tokenize=False,
return_dict=False,
**tokenizer_template_kwargs,
)
if not is_batched:
prompt = prompt[0]
if tokenize:
# Tokenizer's `apply_chat_template` never adds special tokens when tokenizing # Tokenizer's `apply_chat_template` never adds special tokens when tokenizing
# But processor's `apply_chat_template` didn't have an option to tokenize, so users had to format the prompt # But processor's `apply_chat_template` didn't have an option to tokenize, so users had to format the prompt
# and pass it to the processor. Users thus never worried about special tokens relying on processor hadnling # and pass it to the processor. Users thus never worried about special tokens relying on processor hadnling

View File

@@ -22,6 +22,7 @@ from pathlib import Path
from typing import Optional from typing import Optional
import numpy as np import numpy as np
from huggingface_hub import hf_hub_download
from transformers.models.auto.processing_auto import processor_class_from_name from transformers.models.auto.processing_auto import processor_class_from_name
from transformers.processing_utils import Unpack from transformers.processing_utils import Unpack
@@ -538,7 +539,7 @@ class ProcessorTesterMixin:
def test_chat_template_save_loading(self): def test_chat_template_save_loading(self):
processor = self.get_processor() processor = self.get_processor()
signature = inspect.signature(processor.__call__) signature = inspect.signature(processor.__init__)
if "chat_template" not in {*signature.parameters.keys()}: if "chat_template" not in {*signature.parameters.keys()}:
self.skipTest("Processor doesn't accept chat templates at input") self.skipTest("Processor doesn't accept chat templates at input")
@@ -858,3 +859,119 @@ class ProcessorTesterMixin:
self.assertTrue(self.videos_input_name in out_dict_with_video) self.assertTrue(self.videos_input_name in out_dict_with_video)
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1) self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1)
self.assertEqual(len(out_dict_with_video[self.videos_input_name][0]), 2) self.assertEqual(len(out_dict_with_video[self.videos_input_name][0]), 2)
@require_av
def test_chat_template_video_custom_sampling(self):
    """
    Tests that models can pass their custom callables to sample video indices.

    A `sample_indices_fn` that always selects the first two frames is forwarded through
    `apply_chat_template`; the processed output must then contain one video made of two frames.
    """
    processor = self.get_processor()
    if processor.chat_template is None:
        self.skipTest("Processor has no chat template")

    # Skip processors whose `__call__` has no annotated `videos` argument.
    # `inspect.Parameter.empty` is the public spelling of the "no annotation" sentinel.
    videos_param = inspect.signature(processor.__call__).parameters.get("videos")
    if videos_param is None or videos_param.annotation is inspect.Parameter.empty:
        self.skipTest("Processor doesn't accept videos at input")

    video_file_path = hf_hub_download(
        repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset"
    )
    messages = [
        [
            {
                "role": "user",
                "content": [
                    {
                        "type": "video",
                        "path": video_file_path,
                    },
                    {"type": "text", "text": "What is shown in this video?"},
                ],
            },
        ]
    ]

    def dummy_sample_indices_fn(metadata, **fn_kwargs):
        # Always sample only the first two frames, whatever the metadata says
        return [0, 1]

    out_dict_with_video = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        sample_indices_fn=dummy_sample_indices_fn,
    )
    self.assertIn(self.videos_input_name, out_dict_with_video)
    # One conversation with one video, reduced to the two frames chosen by the custom sampler
    self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1)
    self.assertEqual(len(out_dict_with_video[self.videos_input_name][0]), 2)
@require_av
def test_chat_template_video_special_processing(self):
    """
    Tests that models can use their own preprocessing to preprocess conversations.

    The processor's `_process_messages_for_chat_template` hook is replaced by a stub that discards
    the incoming conversation and returns a fixed one; the rendered prompt must contain the stub's text.
    """
    processor = self.get_processor()
    if processor.chat_template is None:
        self.skipTest("Processor has no chat template")

    # Skip processors whose `__call__` has no annotated `videos` argument.
    # `inspect.Parameter.empty` is the public spelling of the "no annotation" sentinel.
    videos_param = inspect.signature(processor.__call__).parameters.get("videos")
    if videos_param is None or videos_param.annotation is inspect.Parameter.empty:
        self.skipTest("Processor doesn't accept videos at input")

    video_file_path = hf_hub_download(
        repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset"
    )
    messages = [
        [
            {
                "role": "user",
                "content": [
                    {"type": "video", "path": video_file_path},
                    {"type": "text", "text": "What is shown in this video?"},
                ],
            },
        ]
    ]

    def _process_messages_for_chat_template(
        conversation,
        batch_images,
        batch_videos,
        batch_video_metadata,
        **chat_template_kwargs,
    ):
        # Let us just always return a dummy prompt, ignoring every input
        new_msg = [
            [
                {
                    "role": "user",
                    "content": [
                        {"type": "video"},  # no need to use path, video is loaded already by this moment
                        {"type": "text", "text": "Dummy prompt for preprocess testing"},
                    ],
                },
            ]
        ]
        return new_msg

    # Assigned on the instance, so the stub is a plain function and is called without `self`
    processor._process_messages_for_chat_template = _process_messages_for_chat_template
    out_dict_with_video = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
    )
    self.assertIn(self.videos_input_name, out_dict_with_video)

    # Check with `in` because we don't know how each template formats the prompt with BOS/EOS/etc
    formatted_text = processor.batch_decode(out_dict_with_video["input_ids"], skip_special_tokens=True)[0]
    self.assertIn("Dummy prompt for preprocess testing", formatted_text)

    self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1)
    # No sampling arguments are passed, so presumably every frame of the sample clip is kept;
    # 243 is the frame count this test expects for `sample_demo_1.mp4` - update if the fixture changes
    self.assertEqual(len(out_dict_with_video[self.videos_input_name][0]), 243)