From 15ec971b8ec999c6a511debe04ba32c115fb7413 Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Fri, 14 Feb 2025 11:34:08 +0100 Subject: [PATCH] Prepare processors for VideoLLMs (#36149) * allow processor to preprocess conversation + video metadata * allow callable * add test * fix test * nit: fix * add metadata frames_indices * Update src/transformers/processing_utils.py Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com> * Update src/transformers/processing_utils.py Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com> * port updates from Orr and add one more test * Update src/transformers/processing_utils.py Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com> * typo * as dataclass * style * docstring + maek sure tests green --------- Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com> --- src/transformers/image_utils.py | 275 +++++++++++++++++---------- src/transformers/processing_utils.py | 172 +++++++++++++---- tests/test_processing_common.py | 119 +++++++++++- 3 files changed, 428 insertions(+), 138 deletions(-) diff --git a/src/transformers/image_utils.py b/src/transformers/image_utils.py index ad439b5d9f..101de2c78a 100644 --- a/src/transformers/image_utils.py +++ b/src/transformers/image_utils.py @@ -18,7 +18,7 @@ import os from contextlib import redirect_stdout from dataclasses import dataclass from io import BytesIO -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Tuple, Union import numpy as np import requests @@ -126,6 +126,14 @@ class AnnotionFormat(ExplicitEnum): COCO_PANOPTIC = AnnotationFormat.COCO_PANOPTIC.value +@dataclass +class VideoMetadata: + total_num_frames: int + fps: float + duration: float + video_backend: str + + AnnotationType = Dict[str, Union[int, str, List[Dict]]] @@ -541,62 +549,83 @@ def load_image(image: Union[str, "PIL.Image.Image"], timeout: Optional[float] = return image -def get_uniform_frame_indices(total_num_frames: int, num_frames: Optional[int] = None): +def default_sample_indices_fn(metadata: VideoMetadata, num_frames=None, fps=None, **kwargs): """ - Creates a numpy array for uniform sampling of `num_frame` frames from `total_num_frames` - when loading a video. + A default sampling function that replicates the logic used in get_uniform_frame_indices, + while optionally handling `fps` if `num_frames` is not provided. Args: - total_num_frames (`int`): - Total number of frames that a video has. + metadata (`VideoMetadata`): + `VideoMetadata` object containing metadat about the video, such as "total_num_frames" or "fps". num_frames (`int`, *optional*): - Number of frames to sample uniformly. If not specified, all frames are sampled. - - Returns: - np.ndarray: np array of frame indices that will be sampled. - """ - if num_frames is not None: - indices = np.arange(0, total_num_frames, total_num_frames / num_frames).astype(int) - else: - indices = np.arange(0, total_num_frames).astype(int) - return indices - - -def read_video_opencv(video_path: str, num_frames: Optional[int] = None, fps: Optional[int] = None): - """ - Decode the video with open-cv decoder. - - Args: - video_path (`str`): - Path to the video file. - num_frames (`int`, *optional*): - Number of frames to sample uniformly. Should be passed only when `fps=None`. - If not specified and `fps==None`, all frames are sampled. + Number of frames to sample uniformly. fps (`int`, *optional*): - Number of frames to sample per second. Should be passed only when `num_frames=None`. - If not specified and `num_frames==None`, all frames are sampled. + Desired frames per second. Takes priority over num_frames if both are provided. Returns: - np.ndarray: np array of decoded frames of shape (num_frames, height, width, 3). + `np.ndarray`: Array of frame indices to sample. """ - video = cv2.VideoCapture(video_path) - total_num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) - video_fps = video.get(cv2.CAP_PROP_FPS) + total_num_frames = metadata.total_num_frames + video_fps = metadata.fps + + # If num_frames is not given but fps is, calculate num_frames from fps if num_frames is None and fps is not None: num_frames = int(total_num_frames / video_fps * fps) if num_frames > total_num_frames: raise ValueError( - f"When loading the video with fps={fps}, we identified that num_frames ({num_frames}) > total_frames ({total_num_frames}) ." - f"Make sure that fps of a video is less than the requested fps for loading. Detected video_fps={video_fps}" + f"When loading the video with fps={fps}, we computed num_frames={num_frames} " + f"which exceeds total_num_frames={total_num_frames}. Check fps or video metadata." ) - indices = get_uniform_frame_indices(total_num_frames, num_frames=num_frames) + + if num_frames is not None: + indices = np.arange(0, total_num_frames, total_num_frames / num_frames, dtype=int) + else: + indices = np.arange(0, total_num_frames, dtype=int) + return indices + + +def read_video_opencv( + video_path: str, + sample_indices_fn: Callable, + **kwargs, +): + """ + Decode a video using the OpenCV backend. + + Args: + video_path (`str`): + Path to the video file. + sample_indices_fn (`Callable`): + A callable function that will return indices at which the video should be sampled. If the video has to be loaded using + by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`. + If not provided, simple uniform sampling with fps is performed. + Example: + def sample_indices_fn(metadata, **kwargs): + return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int) + + Returns: + Tuple[`np.array`, `VideoMetadata`]: A tuple containing: + - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]). + - `VideoMetadata` object. + """ + video = cv2.VideoCapture(video_path) + total_num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) + video_fps = video.get(cv2.CAP_PROP_FPS) + duration = total_num_frames / video_fps if video_fps else 0 + metadata = VideoMetadata( + total_num_frames=int(total_num_frames), fps=float(video_fps), duration=float(duration), video_backend="opencv" + ) + indices = sample_indices_fn(metadata=metadata, **kwargs) index = 0 frames = [] while video.isOpened(): success, frame = video.read() + if not success: + break if index in indices: height, width, channel = frame.shape + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) frames.append(frame[0:height, 0:width, 0:channel]) if success: index += 1 @@ -604,70 +633,81 @@ def read_video_opencv(video_path: str, num_frames: Optional[int] = None, fps: Op break video.release() - return np.stack(frames) + metadata.frames_indices = indices + return np.stack(frames), metadata -def read_video_decord(video_path: str, num_frames: Optional[int] = None, fps: Optional[int] = None): +def read_video_decord( + video_path: str, + sample_indices_fn: Optional[Callable] = None, + **kwargs, +): """ - Decode the video with Decord decoder. + Decode a video using the Decord backend. Args: video_path (`str`): Path to the video file. - num_frames (`int`, *optional*): - Number of frames to sample uniformly. Should be passed only when `fps=None`. - If not specified and `fps==None`, all frames are sampled. - fps (`int`, *optional*): - Number of frames to sample per second. Should be passed only when `num_frames=None`. - If not specified and `num_frames==None`, all frames are sampled. + sample_indices_fn (`Callable`, *optional*): + A callable function that will return indices at which the video should be sampled. If the video has to be loaded using + by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`. + If not provided, simple uniform sampling with fps is performed. + Example: + def sample_indices_fn(metadata, **kwargs): + return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int) Returns: - np.ndarray: np array of decoded frames of shape (num_frames, height, width, 3). + Tuple[`np.array`, `VideoMetadata`]: A tuple containing: + - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]). + - `VideoMetadata` object. """ vr = VideoReader(uri=video_path, ctx=cpu(0)) # decord has problems with gpu video_fps = vr.get_avg_fps() total_num_frames = len(vr) - if num_frames is None and fps is not None: - num_frames = int(total_num_frames / video_fps * fps) - if num_frames > total_num_frames: - raise ValueError( - f"When loading the video with fps={fps}, we identified that num_frames ({num_frames}) > total_frames ({total_num_frames}) ." - f"Make sure that fps of a video is less than the requested fps for loading. Detected video_fps={video_fps}" - ) - indices = get_uniform_frame_indices(total_num_frames=total_num_frames, num_frames=num_frames) + duration = total_num_frames / video_fps if video_fps else 0 + metadata = VideoMetadata( + total_num_frames=int(total_num_frames), fps=float(video_fps), duration=float(duration), video_backend="decord" + ) + + indices = sample_indices_fn(metadata=metadata, **kwargs) + frames = vr.get_batch(indices).asnumpy() - return frames + metadata.frames_indices = indices + return frames, metadata -def read_video_pyav(video_path: str, num_frames: Optional[int] = None, fps: Optional[int] = None): +def read_video_pyav( + video_path: str, + sample_indices_fn: Callable, + **kwargs, +): """ Decode the video with PyAV decoder. Args: video_path (`str`): Path to the video file. - num_frames (`int`, *optional*): - Number of frames to sample uniformly. Should be passed only when `fps=None`. - If not specified and `fps==None`, all frames are sampled. - fps (`int`, *optional*): - Number of frames to sample per second. Should be passed only when `num_frames=None`. - If not specified and `num_frames==None`, all frames are sampled. + sample_indices_fn (`Callable`, *optional*): + A callable function that will return indices at which the video should be sampled. If the video has to be loaded using + by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`. + If not provided, simple uniform sampling with fps is performed. + Example: + def sample_indices_fn(metadata, **kwargs): + return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int) Returns: - np.ndarray: np array of decoded frames of shape (num_frames, height, width, 3). + Tuple[`np.array`, `VideoMetadata`]: A tuple containing: + - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]). + - `VideoMetadata` object. """ container = av.open(video_path) - total_num_frames = container.streams.video[0].frames video_fps = container.streams.video[0].average_rate # should we better use `av_guess_frame_rate`? - if num_frames is None and fps is not None: - num_frames = int(total_num_frames / video_fps * fps) - if num_frames > total_num_frames: - raise ValueError( - f"When loading the video with fps={fps}, we identified that num_frames ({num_frames}) > total_frames ({total_num_frames}) ." - f"Make sure that fps of a video is less than the requested fps for loading. Detected video_fps={video_fps}" - ) - indices = get_uniform_frame_indices(total_num_frames, num_frames=num_frames) + duration = total_num_frames / video_fps if video_fps else 0 + metadata = VideoMetadata( + total_num_frames=int(total_num_frames), fps=float(video_fps), duration=float(duration), video_backend="pyav" + ) + indices = sample_indices_fn(metadata=metadata, **kwargs) frames = [] container.seek(0) @@ -677,48 +717,58 @@ def read_video_pyav(video_path: str, num_frames: Optional[int] = None, fps: Opti break if i >= 0 and i in indices: frames.append(frame) - return np.stack([x.to_ndarray(format="rgb24") for x in frames]) + + video = np.stack([x.to_ndarray(format="rgb24") for x in frames]) + metadata.frames_indices = indices + return video, metadata -def read_video_torchvision(video_path: str, num_frames: Optional[int] = None, fps: Optional[int] = None): +def read_video_torchvision( + video_path: str, + sample_indices_fn: Callable, + **kwargs, +): """ Decode the video with torchvision decoder. Args: video_path (`str`): Path to the video file. - num_frames (`int`, *optional*): - Number of frames to sample uniformly. Should be passed only when `fps=None`. - If not specified and `fps==None`, all frames are sampled. - fps (`int`, *optional*): - Number of frames to sample per second. Should be passed only when `num_frames=None`. - If not specified and `num_frames==None`, all frames are sampled. + sample_indices_fn (`Callable`, *optional*): + A callable function that will return indices at which the video should be sampled. If the video has to be loaded using + by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`. + If not provided, simple uniform sampling with fps is performed. + Example: + def sample_indices_fn(metadata, **kwargs): + return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int) Returns: - np.ndarray: np array of decoded frames of shape (num_frames, height, width, 3). + Tuple[`np.array`, `VideoMetadata`]: A tuple containing: + - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]). + - `VideoMetadata` object. """ video, _, info = torchvision_io.read_video( video_path, start_pts=0.0, end_pts=None, pts_unit="sec", - output_format="TCHW", + output_format="THWC", ) video_fps = info["video_fps"] - total_num_frames = video.size(0) - 1 - if num_frames is None and fps is not None: - num_frames = int(total_num_frames / video_fps * fps) - if num_frames > total_num_frames: - raise ValueError( - f"When loading the video with fps={fps}, we identified that num_frames ({num_frames}) > total_frames ({total_num_frames}) ." - f"Make sure that fps of a video is less than the requested fps for loading. Detected video_fps={video_fps}" - ) + total_num_frames = video.size(0) + duration = total_num_frames / video_fps if video_fps else 0 + metadata = VideoMetadata( + total_num_frames=int(total_num_frames), + fps=float(video_fps), + duration=float(duration), + video_backend="torchvision", + ) - if num_frames is not None: - idx = torch.linspace(0, video.size(0) - 1, num_frames, dtype=torch.int64) - return video[idx] + indices = sample_indices_fn(metadata=metadata, **kwargs) - return video + video = video[indices].contiguous().numpy() + metadata.frames_indices = indices + return video, metadata VIDEO_DECODERS = { @@ -734,6 +784,8 @@ def load_video( num_frames: Optional[int] = None, fps: Optional[int] = None, backend: str = "opencv", + sample_indices_fn: Optional[Callable] = None, + **kwargs, ) -> np.array: """ Loads `video` to a numpy array. @@ -748,13 +800,36 @@ def load_video( If not specified and `num_frames==None`, all frames are sampled. backend (`str`, *optional*, defaults to `"opencv"`): The backend to use when loading the video. Can be any of ["decord", "pyav", "opencv", "torchvision"]. Defaults to "opencv". + sample_indices_fn (`Callable`, *optional*): + A callable function that will return indices at which the video should be sampled. If the video has to be loaded using + by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`. + If not provided, simple uniformt sampling with fps is performed, otherwise `sample_indices_fn` has priority over other args. + The function expects at input the all args along with all kwargs passed to `load_video` and should output valid + indices at which the video should be sampled. For example: + + Example: + def sample_indices_fn(metadata, **kwargs): + return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int) Returns: - `np.array`: A numpy array of shape (num_frames, channels, height, width). + Tuple[`np.array`, Dict]: A tuple containing: + - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]). + - Metadata dictionary. """ - if fps is not None and num_frames is not None: - raise ValueError("`num_frames` and `fps` are mutually exclusive arguments, please use only one!") + # If `sample_indices_fn` is given, we can accept any args as those might be needed by custom `sample_indices_fn` + if fps is not None and num_frames is not None and sample_indices_fn is None: + raise ValueError( + "`num_frames`, `fps`, and `sample_indices_fn` are mutually exclusive arguments, please use only one!" + ) + + # If user didn't pass a sampling function, create one on the fly with default logic + if sample_indices_fn is None: + + def sample_indices_fn_func(metadata, **fn_kwargs): + return default_sample_indices_fn(metadata, num_frames=num_frames, fps=fps, **fn_kwargs) + + sample_indices_fn = sample_indices_fn_func if video.startswith("https://www.youtube.com") or video.startswith("http://www.youtube.com"): if not is_yt_dlp_available(): @@ -796,8 +871,8 @@ def load_video( ) video_decoder = VIDEO_DECODERS[backend] - video = video_decoder(file_obj, num_frames=num_frames, fps=fps) - return video + video, metadata = video_decoder(file_obj, sample_indices_fn, **kwargs) + return video, metadata def load_images( diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py index f9a0c4c5e8..5b1e45259f 100644 --- a/src/transformers/processing_utils.py +++ b/src/transformers/processing_utils.py @@ -24,13 +24,21 @@ import sys import typing import warnings from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, TypedDict, Union import numpy as np import typing_extensions from .dynamic_module_utils import custom_object_save -from .image_utils import ChannelDimension, is_valid_image, is_vision_available, load_image, load_video +from .image_utils import ( + ChannelDimension, + ImageInput, + VideoInput, + is_valid_image, + is_vision_available, + load_image, + load_video, +) if is_vision_available(): @@ -339,14 +347,10 @@ class ProcessingKwargs(TextKwargs, ImagesKwargs, VideosKwargs, AudioKwargs, Comm } -class ChatTemplateKwargs(TypedDict, total=False): +class TokenizerChatTemplateKwargs(TypedDict, total=False): """ - Keyword arguments for processor chat templates. + Keyword arguments for tokenizer's `apply_chat_template`, when it is called from within a processor. - tokenize (`bool`, *optional*, defaults to `False`): - Whether to tokenize the output or not. - return_dict (`bool`, defaults to `False`): - Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`. tools (`List[Dict]`, *optional*): A list of tools (callable functions) that will be accessible to the model. If the template does not support function calling, this argument will have no effect. Each tool should be passed as a JSON Schema, @@ -373,6 +377,23 @@ class ChatTemplateKwargs(TypedDict, total=False): Whether to return a mask of the assistant generated tokens. For tokens generated by the assistant, the mask will contain 1. For user and system tokens, the mask will contain 0. This functionality is only available for chat templates that support it via the `{% generation %}` keyword. + """ + + tools: Optional[List[Dict]] = None + documents: Optional[List[Dict[str, str]]] = None + add_generation_prompt: Optional[bool] = False + continue_final_message: Optional[bool] = False + return_assistant_tokens_mask: Optional[bool] = False + + +class ProcessorChatTemplateKwargs(TokenizerChatTemplateKwargs, total=False): + """ + Keyword arguments for processor chat templates. + + tokenize (`bool`, *optional*, defaults to `False`): + Whether to tokenize the output or not. + return_dict (`bool`, defaults to `False`): + Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`. num_frames (`int`, *optional*): Number of frames to sample uniformly. If not passed, the whole video is loaded. video_load_backend (`str`, *optional*, defaults to `"pyav"`): @@ -382,22 +403,28 @@ class ChatTemplateKwargs(TypedDict, total=False): video_fps (`int`, *optional*): Number of frames to sample per second. Should be passed only when `num_frames=None`. If not specified and `num_frames==None`, all frames are sampled. + sample_indices_fn (`Callable`, *optional*): + A callable function that will return indices at which the video should be sampled. If the video has to be loaded using + by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`. + If not provided, simple uniformt sampling with fps is performed, otherwise `sample_indices_fn` has priority over other args. + The function expects at input the all args along with all kwargs passed to `load_video` and should output valid + indices at which the video should be sampled. For example: + + def sample_indices_fn(num_frames, fps, metadata, **kwargs): + # add you sampling logic here ... + return np.linspace(start_idx, end_idx, num_frames, dtype=int) """ tokenize: Optional[bool] = False return_dict: Optional[bool] = False - tools: Optional[List[Dict]] = None - documents: Optional[List[Dict[str, str]]] = None - add_generation_prompt: Optional[bool] = False - continue_final_message: Optional[bool] = False - return_assistant_tokens_mask: Optional[bool] = False num_frames: Optional[int] = None video_load_backend: Optional[str] = "pyav" video_fps: Optional[int] = None + sample_indices_fn: Optional[Callable] = None class AllKwargsForChatTemplate( - TextKwargs, ImagesKwargs, VideosKwargs, AudioKwargs, CommonKwargs, ChatTemplateKwargs + TextKwargs, ImagesKwargs, VideosKwargs, AudioKwargs, CommonKwargs, ProcessorChatTemplateKwargs ): ... @@ -1165,9 +1192,43 @@ class ProcessorMixin(PushToHubMixin): ) return {arg_name: arg_value for arg_value, arg_name in zip(args, self.optional_call_args)} + def _process_messages_for_chat_template( + self, + conversation: List[List[Dict[str, str]]], + batch_images: List[ImageInput], + batch_videos: List[VideoInput], + batch_video_metadata: List[List[Dict[str, any]]], + **chat_template_kwargs: Unpack[AllKwargsForChatTemplate], + ): + """ + Used within `apply_chat_template` when a model has a special way to process conversation history. For example, + video models might want to specify in the prompt the duration of video or which frame indices at which timestamps + were sampled. This information cannot be accessed before the video is loaded. + + For most models it is a no-op, and must be overriden by model processors which require special processing. + + Args: + conversation (`List[Dict, str, str]`): + The conversation to process. Always comes in batched format. + batch_images (`List[List[ImageInput]]`): + Batch of images that were loaded from url/path defined in the conversation. The images + are ordered in the same way as in the conversation. Comes in nested list format, one list of `PIL` images + per batch. + batch_videos (`List[List[ImageInput]]`): + Batch of videos that were loaded from url/path defined in the conversation. The videos + are ordered in the samm way as in the conversation. Comes in nested list format, one list of 4D video arrays + per batch. + batch_video_metadata (`List[List[Dict[[str, any]]]]`): + Batch of metadata returned from loading videos. That includes video fps, duration and total number of framer in original video. + Metadata are ordered in the same way as `batch_videos`. Comes in nested list format, one list of 4D video arrays + per batch. + + """ + return conversation + def apply_chat_template( self, - conversation: Union[List[Dict[str, str]]], + conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]]], chat_template: Optional[str] = None, **kwargs: Unpack[AllKwargsForChatTemplate], ) -> str: @@ -1190,7 +1251,7 @@ class ProcessorMixin(PushToHubMixin): ] Args: - conversation (`List[Dict, str, str]`): + conversation (`Union[List[Dict, [str, str]], List[List[Dict[str, str]]]]`): The conversation to format. chat_template (`Optional[str]`, *optional*): The Jinja template to use for formatting the conversation. If not provided, the tokenizer's @@ -1207,39 +1268,42 @@ class ProcessorMixin(PushToHubMixin): "https://huggingface.co/docs/transformers/main/en/chat_templating for more information." ) + # Fill two sets of kwargs that should be used by tokenizer's `apply_chat_template` + # and for multimodal chat template + tokenizer_template_kwargs = {} + for tokenizer_key in TokenizerChatTemplateKwargs.__annotations__.keys(): + tokenizer_value = getattr(TokenizerChatTemplateKwargs, tokenizer_key, None) + value = kwargs.pop(tokenizer_key, tokenizer_value) + tokenizer_template_kwargs[tokenizer_key] = value + chat_template_kwargs = {} - for key in ChatTemplateKwargs.__annotations__.keys(): - value = kwargs.pop(key, getattr(ChatTemplateKwargs, key)) + for key in ProcessorChatTemplateKwargs.__annotations__.keys(): + processor_value = getattr(ProcessorChatTemplateKwargs, key, None) + value = kwargs.pop(key, processor_value) chat_template_kwargs[key] = value - # Pop kwargs that should not be used by tokenizer's `apply_chat_template` - tokenize = chat_template_kwargs.pop("tokenize") - return_dict = chat_template_kwargs.pop("return_dict") - num_frames = chat_template_kwargs.pop("num_frames") - video_fps = chat_template_kwargs.pop("video_fps") - video_load_backend = chat_template_kwargs.pop("video_load_backend") - - prompt = self.tokenizer.apply_chat_template( - conversation, - chat_template=chat_template, - tokenize=False, - return_dict=False, - **chat_template_kwargs, - ) - if isinstance(conversation, (list, tuple)) and ( isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "content") ): - conversations = conversation is_batched = True + conversations = conversation else: - conversations = [conversation] is_batched = False + conversations = [conversation] + + num_frames = chat_template_kwargs.get("num_frames") + video_fps = chat_template_kwargs.get("video_fps") + video_load_backend = chat_template_kwargs.get("video_load_backend") + tokenize = chat_template_kwargs.get("tokenize") + return_dict = chat_template_kwargs.get("return_dict") + sample_indices_fn = chat_template_kwargs.get("sample_indices_fn") if tokenize: batch_images, batch_videos = [], [] + batch_video_metadata = [] for conversation in conversations: images, videos = [], [] + video_metadata = [] for message in conversation: visuals = [content for content in message["content"] if content["type"] in ["image", "video"]] image_fnames = [ @@ -1261,17 +1325,51 @@ class ProcessorMixin(PushToHubMixin): video = [np.array(load_image(image_fname)).T for image_fname in fname] # create a 4D video because `load_video` always returns a 4D array video = np.stack(video) + metadata = None + logger.warning( + "When loading the video from list of images, we cannot infer metadata such as `fps` or `duration`. " + "If you model applies special processing based on metadata, please load the whole video and let the model sample frames." + ) else: - video = load_video(fname, num_frames=num_frames, fps=video_fps, backend=video_load_backend) + video, metadata = load_video( + fname, + num_frames=num_frames, + fps=video_fps, + backend=video_load_backend, + sample_indices_fn=sample_indices_fn, + ) videos.append(video) + video_metadata.append(metadata) - # Currently all processors can accept accept nested list of batches, but not flat list of visuals + # Currently all processors can accept nested list of batches, but not flat list of visuals # So we'll make a batched list of images and let the processor handle it if images: batch_images.append(images) if videos: batch_videos.append(videos) + batch_video_metadata.append(video_metadata) + # Process conversation with video/image information if needed. Then convert into a prompt using Jinja template + conversations = self._process_messages_for_chat_template( + conversations, + batch_images=batch_images, + batch_videos=batch_videos, + batch_video_metadata=batch_video_metadata, + **chat_template_kwargs, + ) + + prompt = self.tokenizer.apply_chat_template( + conversations, + chat_template=chat_template, + tokenize=False, + return_dict=False, + **tokenizer_template_kwargs, + ) + + if not is_batched: + prompt = prompt[0] + + if tokenize: # Tokenizer's `apply_chat_template` never adds special tokens when tokenizing # But processor's `apply_chat_template` didn't have an option to tokenize, so users had to format the prompt # and pass it to the processor. Users thus never worried about special tokens relying on processor hadnling diff --git a/tests/test_processing_common.py b/tests/test_processing_common.py index fd3345cbbd..28aff79c7e 100644 --- a/tests/test_processing_common.py +++ b/tests/test_processing_common.py @@ -22,6 +22,7 @@ from pathlib import Path from typing import Optional import numpy as np +from huggingface_hub import hf_hub_download from transformers.models.auto.processing_auto import processor_class_from_name from transformers.processing_utils import Unpack @@ -538,7 +539,7 @@ class ProcessorTesterMixin: def test_chat_template_save_loading(self): processor = self.get_processor() - signature = inspect.signature(processor.__call__) + signature = inspect.signature(processor.__init__) if "chat_template" not in {*signature.parameters.keys()}: self.skipTest("Processor doesn't accept chat templates at input") @@ -858,3 +859,119 @@ class ProcessorTesterMixin: self.assertTrue(self.videos_input_name in out_dict_with_video) self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1) self.assertEqual(len(out_dict_with_video[self.videos_input_name][0]), 2) + + @require_av + def test_chat_template_video_custom_sampling(self): + """ + Tests that models can pass their custom callables to sample video indices. + """ + processor = self.get_processor() + if processor.chat_template is None: + self.skipTest("Processor has no chat template") + + signature = inspect.signature(processor.__call__) + if "videos" not in {*signature.parameters.keys()} or ( + signature.parameters.get("videos") is not None + and signature.parameters["videos"].annotation == inspect._empty + ): + self.skipTest("Processor doesn't accept videos at input") + + video_file_path = hf_hub_download( + repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset" + ) + messages = [ + [ + { + "role": "user", + "content": [ + { + "type": "video", + "path": video_file_path, + }, + {"type": "text", "text": "What is shown in this video?"}, + ], + }, + ] + ] + + def dummmy_sample_indices_fn(metadata, **fn_kwargs): + # sample only the first two frame always + return [0, 1] + + out_dict_with_video = processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + sample_indices_fn=dummmy_sample_indices_fn, + ) + self.assertTrue(self.videos_input_name in out_dict_with_video) + self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1) + self.assertEqual(len(out_dict_with_video[self.videos_input_name][0]), 2) + + @require_av + def test_chat_template_video_special_processing(self): + """ + Tests that models can use their own preprocessing to preprocess conversations. + """ + processor = self.get_processor() + if processor.chat_template is None: + self.skipTest("Processor has no chat template") + + signature = inspect.signature(processor.__call__) + if "videos" not in {*signature.parameters.keys()} or ( + signature.parameters.get("videos") is not None + and signature.parameters["videos"].annotation == inspect._empty + ): + self.skipTest("Processor doesn't accept videos at input") + + video_file_path = hf_hub_download( + repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset" + ) + messages = [ + [ + { + "role": "user", + "content": [ + {"type": "video", "path": video_file_path}, + {"type": "text", "text": "What is shown in this video?"}, + ], + }, + ] + ] + + def _process_messages_for_chat_template( + conversation, + batch_images, + batch_videos, + batch_video_metadata, + **chat_template_kwargs, + ): + # Let us just always return a dummy prompt + new_msg = [ + [ + { + "role": "user", + "content": [ + {"type": "video"}, # no need to use path, video is loaded already by this moment + {"type": "text", "text": "Dummy prompt for preprocess testing"}, + ], + }, + ] + ] + return new_msg + + processor._process_messages_for_chat_template = _process_messages_for_chat_template + out_dict_with_video = processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + ) + self.assertTrue(self.videos_input_name in out_dict_with_video) + + # Check with `in` because we don't know how each template formats the prompt with BOS/EOS/etc + formatted_text = processor.batch_decode(out_dict_with_video["input_ids"], skip_special_tokens=True)[0] + self.assertTrue("Dummy prompt for preprocess testing" in formatted_text) + self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1) + self.assertEqual(len(out_dict_with_video[self.videos_input_name][0]), 243)