Prepare processors for VideoLLMs (#36149)
* allow processor to preprocess conversation + video metadata * allow callable * add test * fix test * nit: fix * add metadata frames_indices * Update src/transformers/processing_utils.py Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com> * Update src/transformers/processing_utils.py Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com> * port updates from Orr and add one more test * Update src/transformers/processing_utils.py Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com> * typo * as dataclass * style * docstring + maek sure tests green --------- Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
33d1d715b0
commit
15ec971b8e
@@ -18,7 +18,7 @@ import os
|
||||
from contextlib import redirect_stdout
|
||||
from dataclasses import dataclass
|
||||
from io import BytesIO
|
||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
@@ -126,6 +126,14 @@ class AnnotionFormat(ExplicitEnum):
|
||||
COCO_PANOPTIC = AnnotationFormat.COCO_PANOPTIC.value
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoMetadata:
|
||||
total_num_frames: int
|
||||
fps: float
|
||||
duration: float
|
||||
video_backend: str
|
||||
|
||||
|
||||
AnnotationType = Dict[str, Union[int, str, List[Dict]]]
|
||||
|
||||
|
||||
@@ -541,62 +549,83 @@ def load_image(image: Union[str, "PIL.Image.Image"], timeout: Optional[float] =
|
||||
return image
|
||||
|
||||
|
||||
def get_uniform_frame_indices(total_num_frames: int, num_frames: Optional[int] = None):
|
||||
def default_sample_indices_fn(metadata: VideoMetadata, num_frames=None, fps=None, **kwargs):
|
||||
"""
|
||||
Creates a numpy array for uniform sampling of `num_frame` frames from `total_num_frames`
|
||||
when loading a video.
|
||||
A default sampling function that replicates the logic used in get_uniform_frame_indices,
|
||||
while optionally handling `fps` if `num_frames` is not provided.
|
||||
|
||||
Args:
|
||||
total_num_frames (`int`):
|
||||
Total number of frames that a video has.
|
||||
metadata (`VideoMetadata`):
|
||||
`VideoMetadata` object containing metadat about the video, such as "total_num_frames" or "fps".
|
||||
num_frames (`int`, *optional*):
|
||||
Number of frames to sample uniformly. If not specified, all frames are sampled.
|
||||
|
||||
Returns:
|
||||
np.ndarray: np array of frame indices that will be sampled.
|
||||
"""
|
||||
if num_frames is not None:
|
||||
indices = np.arange(0, total_num_frames, total_num_frames / num_frames).astype(int)
|
||||
else:
|
||||
indices = np.arange(0, total_num_frames).astype(int)
|
||||
return indices
|
||||
|
||||
|
||||
def read_video_opencv(video_path: str, num_frames: Optional[int] = None, fps: Optional[int] = None):
|
||||
"""
|
||||
Decode the video with open-cv decoder.
|
||||
|
||||
Args:
|
||||
video_path (`str`):
|
||||
Path to the video file.
|
||||
num_frames (`int`, *optional*):
|
||||
Number of frames to sample uniformly. Should be passed only when `fps=None`.
|
||||
If not specified and `fps==None`, all frames are sampled.
|
||||
Number of frames to sample uniformly.
|
||||
fps (`int`, *optional*):
|
||||
Number of frames to sample per second. Should be passed only when `num_frames=None`.
|
||||
If not specified and `num_frames==None`, all frames are sampled.
|
||||
Desired frames per second. Takes priority over num_frames if both are provided.
|
||||
|
||||
Returns:
|
||||
np.ndarray: np array of decoded frames of shape (num_frames, height, width, 3).
|
||||
`np.ndarray`: Array of frame indices to sample.
|
||||
"""
|
||||
video = cv2.VideoCapture(video_path)
|
||||
total_num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
video_fps = video.get(cv2.CAP_PROP_FPS)
|
||||
total_num_frames = metadata.total_num_frames
|
||||
video_fps = metadata.fps
|
||||
|
||||
# If num_frames is not given but fps is, calculate num_frames from fps
|
||||
if num_frames is None and fps is not None:
|
||||
num_frames = int(total_num_frames / video_fps * fps)
|
||||
if num_frames > total_num_frames:
|
||||
raise ValueError(
|
||||
f"When loading the video with fps={fps}, we identified that num_frames ({num_frames}) > total_frames ({total_num_frames}) ."
|
||||
f"Make sure that fps of a video is less than the requested fps for loading. Detected video_fps={video_fps}"
|
||||
f"When loading the video with fps={fps}, we computed num_frames={num_frames} "
|
||||
f"which exceeds total_num_frames={total_num_frames}. Check fps or video metadata."
|
||||
)
|
||||
indices = get_uniform_frame_indices(total_num_frames, num_frames=num_frames)
|
||||
|
||||
if num_frames is not None:
|
||||
indices = np.arange(0, total_num_frames, total_num_frames / num_frames, dtype=int)
|
||||
else:
|
||||
indices = np.arange(0, total_num_frames, dtype=int)
|
||||
return indices
|
||||
|
||||
|
||||
def read_video_opencv(
|
||||
video_path: str,
|
||||
sample_indices_fn: Callable,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Decode a video using the OpenCV backend.
|
||||
|
||||
Args:
|
||||
video_path (`str`):
|
||||
Path to the video file.
|
||||
sample_indices_fn (`Callable`):
|
||||
A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
|
||||
by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
|
||||
If not provided, simple uniform sampling with fps is performed.
|
||||
Example:
|
||||
def sample_indices_fn(metadata, **kwargs):
|
||||
return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
|
||||
|
||||
Returns:
|
||||
Tuple[`np.array`, `VideoMetadata`]: A tuple containing:
|
||||
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
|
||||
- `VideoMetadata` object.
|
||||
"""
|
||||
video = cv2.VideoCapture(video_path)
|
||||
total_num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
video_fps = video.get(cv2.CAP_PROP_FPS)
|
||||
duration = total_num_frames / video_fps if video_fps else 0
|
||||
metadata = VideoMetadata(
|
||||
total_num_frames=int(total_num_frames), fps=float(video_fps), duration=float(duration), video_backend="opencv"
|
||||
)
|
||||
indices = sample_indices_fn(metadata=metadata, **kwargs)
|
||||
|
||||
index = 0
|
||||
frames = []
|
||||
while video.isOpened():
|
||||
success, frame = video.read()
|
||||
if not success:
|
||||
break
|
||||
if index in indices:
|
||||
height, width, channel = frame.shape
|
||||
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
frames.append(frame[0:height, 0:width, 0:channel])
|
||||
if success:
|
||||
index += 1
|
||||
@@ -604,70 +633,81 @@ def read_video_opencv(video_path: str, num_frames: Optional[int] = None, fps: Op
|
||||
break
|
||||
|
||||
video.release()
|
||||
return np.stack(frames)
|
||||
metadata.frames_indices = indices
|
||||
return np.stack(frames), metadata
|
||||
|
||||
|
||||
def read_video_decord(video_path: str, num_frames: Optional[int] = None, fps: Optional[int] = None):
|
||||
def read_video_decord(
|
||||
video_path: str,
|
||||
sample_indices_fn: Optional[Callable] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Decode the video with Decord decoder.
|
||||
Decode a video using the Decord backend.
|
||||
|
||||
Args:
|
||||
video_path (`str`):
|
||||
Path to the video file.
|
||||
num_frames (`int`, *optional*):
|
||||
Number of frames to sample uniformly. Should be passed only when `fps=None`.
|
||||
If not specified and `fps==None`, all frames are sampled.
|
||||
fps (`int`, *optional*):
|
||||
Number of frames to sample per second. Should be passed only when `num_frames=None`.
|
||||
If not specified and `num_frames==None`, all frames are sampled.
|
||||
sample_indices_fn (`Callable`, *optional*):
|
||||
A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
|
||||
by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
|
||||
If not provided, simple uniform sampling with fps is performed.
|
||||
Example:
|
||||
def sample_indices_fn(metadata, **kwargs):
|
||||
return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
|
||||
|
||||
Returns:
|
||||
np.ndarray: np array of decoded frames of shape (num_frames, height, width, 3).
|
||||
Tuple[`np.array`, `VideoMetadata`]: A tuple containing:
|
||||
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
|
||||
- `VideoMetadata` object.
|
||||
"""
|
||||
vr = VideoReader(uri=video_path, ctx=cpu(0)) # decord has problems with gpu
|
||||
video_fps = vr.get_avg_fps()
|
||||
total_num_frames = len(vr)
|
||||
if num_frames is None and fps is not None:
|
||||
num_frames = int(total_num_frames / video_fps * fps)
|
||||
if num_frames > total_num_frames:
|
||||
raise ValueError(
|
||||
f"When loading the video with fps={fps}, we identified that num_frames ({num_frames}) > total_frames ({total_num_frames}) ."
|
||||
f"Make sure that fps of a video is less than the requested fps for loading. Detected video_fps={video_fps}"
|
||||
duration = total_num_frames / video_fps if video_fps else 0
|
||||
metadata = VideoMetadata(
|
||||
total_num_frames=int(total_num_frames), fps=float(video_fps), duration=float(duration), video_backend="decord"
|
||||
)
|
||||
indices = get_uniform_frame_indices(total_num_frames=total_num_frames, num_frames=num_frames)
|
||||
|
||||
indices = sample_indices_fn(metadata=metadata, **kwargs)
|
||||
|
||||
frames = vr.get_batch(indices).asnumpy()
|
||||
return frames
|
||||
metadata.frames_indices = indices
|
||||
return frames, metadata
|
||||
|
||||
|
||||
def read_video_pyav(video_path: str, num_frames: Optional[int] = None, fps: Optional[int] = None):
|
||||
def read_video_pyav(
|
||||
video_path: str,
|
||||
sample_indices_fn: Callable,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Decode the video with PyAV decoder.
|
||||
|
||||
Args:
|
||||
video_path (`str`):
|
||||
Path to the video file.
|
||||
num_frames (`int`, *optional*):
|
||||
Number of frames to sample uniformly. Should be passed only when `fps=None`.
|
||||
If not specified and `fps==None`, all frames are sampled.
|
||||
fps (`int`, *optional*):
|
||||
Number of frames to sample per second. Should be passed only when `num_frames=None`.
|
||||
If not specified and `num_frames==None`, all frames are sampled.
|
||||
sample_indices_fn (`Callable`, *optional*):
|
||||
A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
|
||||
by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
|
||||
If not provided, simple uniform sampling with fps is performed.
|
||||
Example:
|
||||
def sample_indices_fn(metadata, **kwargs):
|
||||
return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
|
||||
|
||||
Returns:
|
||||
np.ndarray: np array of decoded frames of shape (num_frames, height, width, 3).
|
||||
Tuple[`np.array`, `VideoMetadata`]: A tuple containing:
|
||||
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
|
||||
- `VideoMetadata` object.
|
||||
"""
|
||||
container = av.open(video_path)
|
||||
|
||||
total_num_frames = container.streams.video[0].frames
|
||||
video_fps = container.streams.video[0].average_rate # should we better use `av_guess_frame_rate`?
|
||||
if num_frames is None and fps is not None:
|
||||
num_frames = int(total_num_frames / video_fps * fps)
|
||||
if num_frames > total_num_frames:
|
||||
raise ValueError(
|
||||
f"When loading the video with fps={fps}, we identified that num_frames ({num_frames}) > total_frames ({total_num_frames}) ."
|
||||
f"Make sure that fps of a video is less than the requested fps for loading. Detected video_fps={video_fps}"
|
||||
duration = total_num_frames / video_fps if video_fps else 0
|
||||
metadata = VideoMetadata(
|
||||
total_num_frames=int(total_num_frames), fps=float(video_fps), duration=float(duration), video_backend="pyav"
|
||||
)
|
||||
indices = get_uniform_frame_indices(total_num_frames, num_frames=num_frames)
|
||||
indices = sample_indices_fn(metadata=metadata, **kwargs)
|
||||
|
||||
frames = []
|
||||
container.seek(0)
|
||||
@@ -677,48 +717,58 @@ def read_video_pyav(video_path: str, num_frames: Optional[int] = None, fps: Opti
|
||||
break
|
||||
if i >= 0 and i in indices:
|
||||
frames.append(frame)
|
||||
return np.stack([x.to_ndarray(format="rgb24") for x in frames])
|
||||
|
||||
video = np.stack([x.to_ndarray(format="rgb24") for x in frames])
|
||||
metadata.frames_indices = indices
|
||||
return video, metadata
|
||||
|
||||
|
||||
def read_video_torchvision(video_path: str, num_frames: Optional[int] = None, fps: Optional[int] = None):
|
||||
def read_video_torchvision(
|
||||
video_path: str,
|
||||
sample_indices_fn: Callable,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Decode the video with torchvision decoder.
|
||||
|
||||
Args:
|
||||
video_path (`str`):
|
||||
Path to the video file.
|
||||
num_frames (`int`, *optional*):
|
||||
Number of frames to sample uniformly. Should be passed only when `fps=None`.
|
||||
If not specified and `fps==None`, all frames are sampled.
|
||||
fps (`int`, *optional*):
|
||||
Number of frames to sample per second. Should be passed only when `num_frames=None`.
|
||||
If not specified and `num_frames==None`, all frames are sampled.
|
||||
sample_indices_fn (`Callable`, *optional*):
|
||||
A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
|
||||
by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
|
||||
If not provided, simple uniform sampling with fps is performed.
|
||||
Example:
|
||||
def sample_indices_fn(metadata, **kwargs):
|
||||
return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
|
||||
|
||||
Returns:
|
||||
np.ndarray: np array of decoded frames of shape (num_frames, height, width, 3).
|
||||
Tuple[`np.array`, `VideoMetadata`]: A tuple containing:
|
||||
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
|
||||
- `VideoMetadata` object.
|
||||
"""
|
||||
video, _, info = torchvision_io.read_video(
|
||||
video_path,
|
||||
start_pts=0.0,
|
||||
end_pts=None,
|
||||
pts_unit="sec",
|
||||
output_format="TCHW",
|
||||
output_format="THWC",
|
||||
)
|
||||
video_fps = info["video_fps"]
|
||||
total_num_frames = video.size(0) - 1
|
||||
if num_frames is None and fps is not None:
|
||||
num_frames = int(total_num_frames / video_fps * fps)
|
||||
if num_frames > total_num_frames:
|
||||
raise ValueError(
|
||||
f"When loading the video with fps={fps}, we identified that num_frames ({num_frames}) > total_frames ({total_num_frames}) ."
|
||||
f"Make sure that fps of a video is less than the requested fps for loading. Detected video_fps={video_fps}"
|
||||
total_num_frames = video.size(0)
|
||||
duration = total_num_frames / video_fps if video_fps else 0
|
||||
metadata = VideoMetadata(
|
||||
total_num_frames=int(total_num_frames),
|
||||
fps=float(video_fps),
|
||||
duration=float(duration),
|
||||
video_backend="torchvision",
|
||||
)
|
||||
|
||||
if num_frames is not None:
|
||||
idx = torch.linspace(0, video.size(0) - 1, num_frames, dtype=torch.int64)
|
||||
return video[idx]
|
||||
indices = sample_indices_fn(metadata=metadata, **kwargs)
|
||||
|
||||
return video
|
||||
video = video[indices].contiguous().numpy()
|
||||
metadata.frames_indices = indices
|
||||
return video, metadata
|
||||
|
||||
|
||||
VIDEO_DECODERS = {
|
||||
@@ -734,6 +784,8 @@ def load_video(
|
||||
num_frames: Optional[int] = None,
|
||||
fps: Optional[int] = None,
|
||||
backend: str = "opencv",
|
||||
sample_indices_fn: Optional[Callable] = None,
|
||||
**kwargs,
|
||||
) -> np.array:
|
||||
"""
|
||||
Loads `video` to a numpy array.
|
||||
@@ -748,13 +800,36 @@ def load_video(
|
||||
If not specified and `num_frames==None`, all frames are sampled.
|
||||
backend (`str`, *optional*, defaults to `"opencv"`):
|
||||
The backend to use when loading the video. Can be any of ["decord", "pyav", "opencv", "torchvision"]. Defaults to "opencv".
|
||||
sample_indices_fn (`Callable`, *optional*):
|
||||
A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
|
||||
by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
|
||||
If not provided, simple uniformt sampling with fps is performed, otherwise `sample_indices_fn` has priority over other args.
|
||||
The function expects at input the all args along with all kwargs passed to `load_video` and should output valid
|
||||
indices at which the video should be sampled. For example:
|
||||
|
||||
Example:
|
||||
def sample_indices_fn(metadata, **kwargs):
|
||||
return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
|
||||
|
||||
Returns:
|
||||
`np.array`: A numpy array of shape (num_frames, channels, height, width).
|
||||
Tuple[`np.array`, Dict]: A tuple containing:
|
||||
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
|
||||
- Metadata dictionary.
|
||||
"""
|
||||
|
||||
if fps is not None and num_frames is not None:
|
||||
raise ValueError("`num_frames` and `fps` are mutually exclusive arguments, please use only one!")
|
||||
# If `sample_indices_fn` is given, we can accept any args as those might be needed by custom `sample_indices_fn`
|
||||
if fps is not None and num_frames is not None and sample_indices_fn is None:
|
||||
raise ValueError(
|
||||
"`num_frames`, `fps`, and `sample_indices_fn` are mutually exclusive arguments, please use only one!"
|
||||
)
|
||||
|
||||
# If user didn't pass a sampling function, create one on the fly with default logic
|
||||
if sample_indices_fn is None:
|
||||
|
||||
def sample_indices_fn_func(metadata, **fn_kwargs):
|
||||
return default_sample_indices_fn(metadata, num_frames=num_frames, fps=fps, **fn_kwargs)
|
||||
|
||||
sample_indices_fn = sample_indices_fn_func
|
||||
|
||||
if video.startswith("https://www.youtube.com") or video.startswith("http://www.youtube.com"):
|
||||
if not is_yt_dlp_available():
|
||||
@@ -796,8 +871,8 @@ def load_video(
|
||||
)
|
||||
|
||||
video_decoder = VIDEO_DECODERS[backend]
|
||||
video = video_decoder(file_obj, num_frames=num_frames, fps=fps)
|
||||
return video
|
||||
video, metadata = video_decoder(file_obj, sample_indices_fn, **kwargs)
|
||||
return video, metadata
|
||||
|
||||
|
||||
def load_images(
|
||||
|
||||
@@ -24,13 +24,21 @@ import sys
|
||||
import typing
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, TypedDict, Union
|
||||
|
||||
import numpy as np
|
||||
import typing_extensions
|
||||
|
||||
from .dynamic_module_utils import custom_object_save
|
||||
from .image_utils import ChannelDimension, is_valid_image, is_vision_available, load_image, load_video
|
||||
from .image_utils import (
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
VideoInput,
|
||||
is_valid_image,
|
||||
is_vision_available,
|
||||
load_image,
|
||||
load_video,
|
||||
)
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
@@ -339,14 +347,10 @@ class ProcessingKwargs(TextKwargs, ImagesKwargs, VideosKwargs, AudioKwargs, Comm
|
||||
}
|
||||
|
||||
|
||||
class ChatTemplateKwargs(TypedDict, total=False):
|
||||
class TokenizerChatTemplateKwargs(TypedDict, total=False):
|
||||
"""
|
||||
Keyword arguments for processor chat templates.
|
||||
Keyword arguments for tokenizer's `apply_chat_template`, when it is called from within a processor.
|
||||
|
||||
tokenize (`bool`, *optional*, defaults to `False`):
|
||||
Whether to tokenize the output or not.
|
||||
return_dict (`bool`, defaults to `False`):
|
||||
Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`.
|
||||
tools (`List[Dict]`, *optional*):
|
||||
A list of tools (callable functions) that will be accessible to the model. If the template does not
|
||||
support function calling, this argument will have no effect. Each tool should be passed as a JSON Schema,
|
||||
@@ -373,6 +377,23 @@ class ChatTemplateKwargs(TypedDict, total=False):
|
||||
Whether to return a mask of the assistant generated tokens. For tokens generated by the assistant,
|
||||
the mask will contain 1. For user and system tokens, the mask will contain 0.
|
||||
This functionality is only available for chat templates that support it via the `{% generation %}` keyword.
|
||||
"""
|
||||
|
||||
tools: Optional[List[Dict]] = None
|
||||
documents: Optional[List[Dict[str, str]]] = None
|
||||
add_generation_prompt: Optional[bool] = False
|
||||
continue_final_message: Optional[bool] = False
|
||||
return_assistant_tokens_mask: Optional[bool] = False
|
||||
|
||||
|
||||
class ProcessorChatTemplateKwargs(TokenizerChatTemplateKwargs, total=False):
|
||||
"""
|
||||
Keyword arguments for processor chat templates.
|
||||
|
||||
tokenize (`bool`, *optional*, defaults to `False`):
|
||||
Whether to tokenize the output or not.
|
||||
return_dict (`bool`, defaults to `False`):
|
||||
Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`.
|
||||
num_frames (`int`, *optional*):
|
||||
Number of frames to sample uniformly. If not passed, the whole video is loaded.
|
||||
video_load_backend (`str`, *optional*, defaults to `"pyav"`):
|
||||
@@ -382,22 +403,28 @@ class ChatTemplateKwargs(TypedDict, total=False):
|
||||
video_fps (`int`, *optional*):
|
||||
Number of frames to sample per second. Should be passed only when `num_frames=None`.
|
||||
If not specified and `num_frames==None`, all frames are sampled.
|
||||
sample_indices_fn (`Callable`, *optional*):
|
||||
A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
|
||||
by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
|
||||
If not provided, simple uniformt sampling with fps is performed, otherwise `sample_indices_fn` has priority over other args.
|
||||
The function expects at input the all args along with all kwargs passed to `load_video` and should output valid
|
||||
indices at which the video should be sampled. For example:
|
||||
|
||||
def sample_indices_fn(num_frames, fps, metadata, **kwargs):
|
||||
# add you sampling logic here ...
|
||||
return np.linspace(start_idx, end_idx, num_frames, dtype=int)
|
||||
"""
|
||||
|
||||
tokenize: Optional[bool] = False
|
||||
return_dict: Optional[bool] = False
|
||||
tools: Optional[List[Dict]] = None
|
||||
documents: Optional[List[Dict[str, str]]] = None
|
||||
add_generation_prompt: Optional[bool] = False
|
||||
continue_final_message: Optional[bool] = False
|
||||
return_assistant_tokens_mask: Optional[bool] = False
|
||||
num_frames: Optional[int] = None
|
||||
video_load_backend: Optional[str] = "pyav"
|
||||
video_fps: Optional[int] = None
|
||||
sample_indices_fn: Optional[Callable] = None
|
||||
|
||||
|
||||
class AllKwargsForChatTemplate(
|
||||
TextKwargs, ImagesKwargs, VideosKwargs, AudioKwargs, CommonKwargs, ChatTemplateKwargs
|
||||
TextKwargs, ImagesKwargs, VideosKwargs, AudioKwargs, CommonKwargs, ProcessorChatTemplateKwargs
|
||||
): ...
|
||||
|
||||
|
||||
@@ -1165,9 +1192,43 @@ class ProcessorMixin(PushToHubMixin):
|
||||
)
|
||||
return {arg_name: arg_value for arg_value, arg_name in zip(args, self.optional_call_args)}
|
||||
|
||||
def _process_messages_for_chat_template(
|
||||
self,
|
||||
conversation: List[List[Dict[str, str]]],
|
||||
batch_images: List[ImageInput],
|
||||
batch_videos: List[VideoInput],
|
||||
batch_video_metadata: List[List[Dict[str, any]]],
|
||||
**chat_template_kwargs: Unpack[AllKwargsForChatTemplate],
|
||||
):
|
||||
"""
|
||||
Used within `apply_chat_template` when a model has a special way to process conversation history. For example,
|
||||
video models might want to specify in the prompt the duration of video or which frame indices at which timestamps
|
||||
were sampled. This information cannot be accessed before the video is loaded.
|
||||
|
||||
For most models it is a no-op, and must be overriden by model processors which require special processing.
|
||||
|
||||
Args:
|
||||
conversation (`List[Dict, str, str]`):
|
||||
The conversation to process. Always comes in batched format.
|
||||
batch_images (`List[List[ImageInput]]`):
|
||||
Batch of images that were loaded from url/path defined in the conversation. The images
|
||||
are ordered in the same way as in the conversation. Comes in nested list format, one list of `PIL` images
|
||||
per batch.
|
||||
batch_videos (`List[List[ImageInput]]`):
|
||||
Batch of videos that were loaded from url/path defined in the conversation. The videos
|
||||
are ordered in the samm way as in the conversation. Comes in nested list format, one list of 4D video arrays
|
||||
per batch.
|
||||
batch_video_metadata (`List[List[Dict[[str, any]]]]`):
|
||||
Batch of metadata returned from loading videos. That includes video fps, duration and total number of framer in original video.
|
||||
Metadata are ordered in the same way as `batch_videos`. Comes in nested list format, one list of 4D video arrays
|
||||
per batch.
|
||||
|
||||
"""
|
||||
return conversation
|
||||
|
||||
def apply_chat_template(
|
||||
self,
|
||||
conversation: Union[List[Dict[str, str]]],
|
||||
conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]]],
|
||||
chat_template: Optional[str] = None,
|
||||
**kwargs: Unpack[AllKwargsForChatTemplate],
|
||||
) -> str:
|
||||
@@ -1190,7 +1251,7 @@ class ProcessorMixin(PushToHubMixin):
|
||||
]
|
||||
|
||||
Args:
|
||||
conversation (`List[Dict, str, str]`):
|
||||
conversation (`Union[List[Dict, [str, str]], List[List[Dict[str, str]]]]`):
|
||||
The conversation to format.
|
||||
chat_template (`Optional[str]`, *optional*):
|
||||
The Jinja template to use for formatting the conversation. If not provided, the tokenizer's
|
||||
@@ -1207,39 +1268,42 @@ class ProcessorMixin(PushToHubMixin):
|
||||
"https://huggingface.co/docs/transformers/main/en/chat_templating for more information."
|
||||
)
|
||||
|
||||
# Fill two sets of kwargs that should be used by tokenizer's `apply_chat_template`
|
||||
# and for multimodal chat template
|
||||
tokenizer_template_kwargs = {}
|
||||
for tokenizer_key in TokenizerChatTemplateKwargs.__annotations__.keys():
|
||||
tokenizer_value = getattr(TokenizerChatTemplateKwargs, tokenizer_key, None)
|
||||
value = kwargs.pop(tokenizer_key, tokenizer_value)
|
||||
tokenizer_template_kwargs[tokenizer_key] = value
|
||||
|
||||
chat_template_kwargs = {}
|
||||
for key in ChatTemplateKwargs.__annotations__.keys():
|
||||
value = kwargs.pop(key, getattr(ChatTemplateKwargs, key))
|
||||
for key in ProcessorChatTemplateKwargs.__annotations__.keys():
|
||||
processor_value = getattr(ProcessorChatTemplateKwargs, key, None)
|
||||
value = kwargs.pop(key, processor_value)
|
||||
chat_template_kwargs[key] = value
|
||||
|
||||
# Pop kwargs that should not be used by tokenizer's `apply_chat_template`
|
||||
tokenize = chat_template_kwargs.pop("tokenize")
|
||||
return_dict = chat_template_kwargs.pop("return_dict")
|
||||
num_frames = chat_template_kwargs.pop("num_frames")
|
||||
video_fps = chat_template_kwargs.pop("video_fps")
|
||||
video_load_backend = chat_template_kwargs.pop("video_load_backend")
|
||||
|
||||
prompt = self.tokenizer.apply_chat_template(
|
||||
conversation,
|
||||
chat_template=chat_template,
|
||||
tokenize=False,
|
||||
return_dict=False,
|
||||
**chat_template_kwargs,
|
||||
)
|
||||
|
||||
if isinstance(conversation, (list, tuple)) and (
|
||||
isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "content")
|
||||
):
|
||||
conversations = conversation
|
||||
is_batched = True
|
||||
conversations = conversation
|
||||
else:
|
||||
conversations = [conversation]
|
||||
is_batched = False
|
||||
conversations = [conversation]
|
||||
|
||||
num_frames = chat_template_kwargs.get("num_frames")
|
||||
video_fps = chat_template_kwargs.get("video_fps")
|
||||
video_load_backend = chat_template_kwargs.get("video_load_backend")
|
||||
tokenize = chat_template_kwargs.get("tokenize")
|
||||
return_dict = chat_template_kwargs.get("return_dict")
|
||||
sample_indices_fn = chat_template_kwargs.get("sample_indices_fn")
|
||||
|
||||
if tokenize:
|
||||
batch_images, batch_videos = [], []
|
||||
batch_video_metadata = []
|
||||
for conversation in conversations:
|
||||
images, videos = [], []
|
||||
video_metadata = []
|
||||
for message in conversation:
|
||||
visuals = [content for content in message["content"] if content["type"] in ["image", "video"]]
|
||||
image_fnames = [
|
||||
@@ -1261,17 +1325,51 @@ class ProcessorMixin(PushToHubMixin):
|
||||
video = [np.array(load_image(image_fname)).T for image_fname in fname]
|
||||
# create a 4D video because `load_video` always returns a 4D array
|
||||
video = np.stack(video)
|
||||
metadata = None
|
||||
logger.warning(
|
||||
"When loading the video from list of images, we cannot infer metadata such as `fps` or `duration`. "
|
||||
"If you model applies special processing based on metadata, please load the whole video and let the model sample frames."
|
||||
)
|
||||
else:
|
||||
video = load_video(fname, num_frames=num_frames, fps=video_fps, backend=video_load_backend)
|
||||
video, metadata = load_video(
|
||||
fname,
|
||||
num_frames=num_frames,
|
||||
fps=video_fps,
|
||||
backend=video_load_backend,
|
||||
sample_indices_fn=sample_indices_fn,
|
||||
)
|
||||
videos.append(video)
|
||||
video_metadata.append(metadata)
|
||||
|
||||
# Currently all processors can accept accept nested list of batches, but not flat list of visuals
|
||||
# Currently all processors can accept nested list of batches, but not flat list of visuals
|
||||
# So we'll make a batched list of images and let the processor handle it
|
||||
if images:
|
||||
batch_images.append(images)
|
||||
if videos:
|
||||
batch_videos.append(videos)
|
||||
batch_video_metadata.append(video_metadata)
|
||||
|
||||
# Process conversation with video/image information if needed. Then convert into a prompt using Jinja template
|
||||
conversations = self._process_messages_for_chat_template(
|
||||
conversations,
|
||||
batch_images=batch_images,
|
||||
batch_videos=batch_videos,
|
||||
batch_video_metadata=batch_video_metadata,
|
||||
**chat_template_kwargs,
|
||||
)
|
||||
|
||||
prompt = self.tokenizer.apply_chat_template(
|
||||
conversations,
|
||||
chat_template=chat_template,
|
||||
tokenize=False,
|
||||
return_dict=False,
|
||||
**tokenizer_template_kwargs,
|
||||
)
|
||||
|
||||
if not is_batched:
|
||||
prompt = prompt[0]
|
||||
|
||||
if tokenize:
|
||||
# Tokenizer's `apply_chat_template` never adds special tokens when tokenizing
|
||||
# But processor's `apply_chat_template` didn't have an option to tokenize, so users had to format the prompt
|
||||
# and pass it to the processor. Users thus never worried about special tokens relying on processor hadnling
|
||||
|
||||
@@ -22,6 +22,7 @@ from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from transformers.models.auto.processing_auto import processor_class_from_name
|
||||
from transformers.processing_utils import Unpack
|
||||
@@ -538,7 +539,7 @@ class ProcessorTesterMixin:
|
||||
|
||||
def test_chat_template_save_loading(self):
|
||||
processor = self.get_processor()
|
||||
signature = inspect.signature(processor.__call__)
|
||||
signature = inspect.signature(processor.__init__)
|
||||
if "chat_template" not in {*signature.parameters.keys()}:
|
||||
self.skipTest("Processor doesn't accept chat templates at input")
|
||||
|
||||
@@ -858,3 +859,119 @@ class ProcessorTesterMixin:
|
||||
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1)
|
||||
self.assertEqual(len(out_dict_with_video[self.videos_input_name][0]), 2)
|
||||
|
||||
@require_av
|
||||
def test_chat_template_video_custom_sampling(self):
|
||||
"""
|
||||
Tests that models can pass their custom callables to sample video indices.
|
||||
"""
|
||||
processor = self.get_processor()
|
||||
if processor.chat_template is None:
|
||||
self.skipTest("Processor has no chat template")
|
||||
|
||||
signature = inspect.signature(processor.__call__)
|
||||
if "videos" not in {*signature.parameters.keys()} or (
|
||||
signature.parameters.get("videos") is not None
|
||||
and signature.parameters["videos"].annotation == inspect._empty
|
||||
):
|
||||
self.skipTest("Processor doesn't accept videos at input")
|
||||
|
||||
video_file_path = hf_hub_download(
|
||||
repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset"
|
||||
)
|
||||
messages = [
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "video",
|
||||
"path": video_file_path,
|
||||
},
|
||||
{"type": "text", "text": "What is shown in this video?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
]
|
||||
|
||||
def dummmy_sample_indices_fn(metadata, **fn_kwargs):
|
||||
# sample only the first two frame always
|
||||
return [0, 1]
|
||||
|
||||
out_dict_with_video = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
sample_indices_fn=dummmy_sample_indices_fn,
|
||||
)
|
||||
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1)
|
||||
self.assertEqual(len(out_dict_with_video[self.videos_input_name][0]), 2)
|
||||
|
||||
@require_av
|
||||
def test_chat_template_video_special_processing(self):
|
||||
"""
|
||||
Tests that models can use their own preprocessing to preprocess conversations.
|
||||
"""
|
||||
processor = self.get_processor()
|
||||
if processor.chat_template is None:
|
||||
self.skipTest("Processor has no chat template")
|
||||
|
||||
signature = inspect.signature(processor.__call__)
|
||||
if "videos" not in {*signature.parameters.keys()} or (
|
||||
signature.parameters.get("videos") is not None
|
||||
and signature.parameters["videos"].annotation == inspect._empty
|
||||
):
|
||||
self.skipTest("Processor doesn't accept videos at input")
|
||||
|
||||
video_file_path = hf_hub_download(
|
||||
repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset"
|
||||
)
|
||||
messages = [
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "video", "path": video_file_path},
|
||||
{"type": "text", "text": "What is shown in this video?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
]
|
||||
|
||||
def _process_messages_for_chat_template(
|
||||
conversation,
|
||||
batch_images,
|
||||
batch_videos,
|
||||
batch_video_metadata,
|
||||
**chat_template_kwargs,
|
||||
):
|
||||
# Let us just always return a dummy prompt
|
||||
new_msg = [
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "video"}, # no need to use path, video is loaded already by this moment
|
||||
{"type": "text", "text": "Dummy prompt for preprocess testing"},
|
||||
],
|
||||
},
|
||||
]
|
||||
]
|
||||
return new_msg
|
||||
|
||||
processor._process_messages_for_chat_template = _process_messages_for_chat_template
|
||||
out_dict_with_video = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
)
|
||||
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||
|
||||
# Check with `in` because we don't know how each template formats the prompt with BOS/EOS/etc
|
||||
formatted_text = processor.batch_decode(out_dict_with_video["input_ids"], skip_special_tokens=True)[0]
|
||||
self.assertTrue("Dummy prompt for preprocess testing" in formatted_text)
|
||||
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1)
|
||||
self.assertEqual(len(out_dict_with_video[self.videos_input_name][0]), 243)
|
||||
|
||||
Reference in New Issue
Block a user