Prepare processors for VideoLLMs (#36149)
* allow processor to preprocess conversation + video metadata * allow callable * add test * fix test * nit: fix * add metadata frames_indices * Update src/transformers/processing_utils.py Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com> * Update src/transformers/processing_utils.py Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com> * port updates from Orr and add one more test * Update src/transformers/processing_utils.py Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com> * typo * as dataclass * style * docstring + maek sure tests green --------- Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
33d1d715b0
commit
15ec971b8e
@@ -18,7 +18,7 @@ import os
|
|||||||
from contextlib import redirect_stdout
|
from contextlib import redirect_stdout
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import requests
|
import requests
|
||||||
@@ -126,6 +126,14 @@ class AnnotionFormat(ExplicitEnum):
|
|||||||
COCO_PANOPTIC = AnnotationFormat.COCO_PANOPTIC.value
|
COCO_PANOPTIC = AnnotationFormat.COCO_PANOPTIC.value
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class VideoMetadata:
|
||||||
|
total_num_frames: int
|
||||||
|
fps: float
|
||||||
|
duration: float
|
||||||
|
video_backend: str
|
||||||
|
|
||||||
|
|
||||||
AnnotationType = Dict[str, Union[int, str, List[Dict]]]
|
AnnotationType = Dict[str, Union[int, str, List[Dict]]]
|
||||||
|
|
||||||
|
|
||||||
@@ -541,62 +549,83 @@ def load_image(image: Union[str, "PIL.Image.Image"], timeout: Optional[float] =
|
|||||||
return image
|
return image
|
||||||
|
|
||||||
|
|
||||||
def get_uniform_frame_indices(total_num_frames: int, num_frames: Optional[int] = None):
|
def default_sample_indices_fn(metadata: VideoMetadata, num_frames=None, fps=None, **kwargs):
|
||||||
"""
|
"""
|
||||||
Creates a numpy array for uniform sampling of `num_frame` frames from `total_num_frames`
|
A default sampling function that replicates the logic used in get_uniform_frame_indices,
|
||||||
when loading a video.
|
while optionally handling `fps` if `num_frames` is not provided.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
total_num_frames (`int`):
|
metadata (`VideoMetadata`):
|
||||||
Total number of frames that a video has.
|
`VideoMetadata` object containing metadat about the video, such as "total_num_frames" or "fps".
|
||||||
num_frames (`int`, *optional*):
|
num_frames (`int`, *optional*):
|
||||||
Number of frames to sample uniformly. If not specified, all frames are sampled.
|
Number of frames to sample uniformly.
|
||||||
|
|
||||||
Returns:
|
|
||||||
np.ndarray: np array of frame indices that will be sampled.
|
|
||||||
"""
|
|
||||||
if num_frames is not None:
|
|
||||||
indices = np.arange(0, total_num_frames, total_num_frames / num_frames).astype(int)
|
|
||||||
else:
|
|
||||||
indices = np.arange(0, total_num_frames).astype(int)
|
|
||||||
return indices
|
|
||||||
|
|
||||||
|
|
||||||
def read_video_opencv(video_path: str, num_frames: Optional[int] = None, fps: Optional[int] = None):
|
|
||||||
"""
|
|
||||||
Decode the video with open-cv decoder.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
video_path (`str`):
|
|
||||||
Path to the video file.
|
|
||||||
num_frames (`int`, *optional*):
|
|
||||||
Number of frames to sample uniformly. Should be passed only when `fps=None`.
|
|
||||||
If not specified and `fps==None`, all frames are sampled.
|
|
||||||
fps (`int`, *optional*):
|
fps (`int`, *optional*):
|
||||||
Number of frames to sample per second. Should be passed only when `num_frames=None`.
|
Desired frames per second. Takes priority over num_frames if both are provided.
|
||||||
If not specified and `num_frames==None`, all frames are sampled.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
np.ndarray: np array of decoded frames of shape (num_frames, height, width, 3).
|
`np.ndarray`: Array of frame indices to sample.
|
||||||
"""
|
"""
|
||||||
video = cv2.VideoCapture(video_path)
|
total_num_frames = metadata.total_num_frames
|
||||||
total_num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
|
video_fps = metadata.fps
|
||||||
video_fps = video.get(cv2.CAP_PROP_FPS)
|
|
||||||
|
# If num_frames is not given but fps is, calculate num_frames from fps
|
||||||
if num_frames is None and fps is not None:
|
if num_frames is None and fps is not None:
|
||||||
num_frames = int(total_num_frames / video_fps * fps)
|
num_frames = int(total_num_frames / video_fps * fps)
|
||||||
if num_frames > total_num_frames:
|
if num_frames > total_num_frames:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"When loading the video with fps={fps}, we identified that num_frames ({num_frames}) > total_frames ({total_num_frames}) ."
|
f"When loading the video with fps={fps}, we computed num_frames={num_frames} "
|
||||||
f"Make sure that fps of a video is less than the requested fps for loading. Detected video_fps={video_fps}"
|
f"which exceeds total_num_frames={total_num_frames}. Check fps or video metadata."
|
||||||
)
|
)
|
||||||
indices = get_uniform_frame_indices(total_num_frames, num_frames=num_frames)
|
|
||||||
|
if num_frames is not None:
|
||||||
|
indices = np.arange(0, total_num_frames, total_num_frames / num_frames, dtype=int)
|
||||||
|
else:
|
||||||
|
indices = np.arange(0, total_num_frames, dtype=int)
|
||||||
|
return indices
|
||||||
|
|
||||||
|
|
||||||
|
def read_video_opencv(
|
||||||
|
video_path: str,
|
||||||
|
sample_indices_fn: Callable,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Decode a video using the OpenCV backend.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
video_path (`str`):
|
||||||
|
Path to the video file.
|
||||||
|
sample_indices_fn (`Callable`):
|
||||||
|
A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
|
||||||
|
by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
|
||||||
|
If not provided, simple uniform sampling with fps is performed.
|
||||||
|
Example:
|
||||||
|
def sample_indices_fn(metadata, **kwargs):
|
||||||
|
return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[`np.array`, `VideoMetadata`]: A tuple containing:
|
||||||
|
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
|
||||||
|
- `VideoMetadata` object.
|
||||||
|
"""
|
||||||
|
video = cv2.VideoCapture(video_path)
|
||||||
|
total_num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||||
|
video_fps = video.get(cv2.CAP_PROP_FPS)
|
||||||
|
duration = total_num_frames / video_fps if video_fps else 0
|
||||||
|
metadata = VideoMetadata(
|
||||||
|
total_num_frames=int(total_num_frames), fps=float(video_fps), duration=float(duration), video_backend="opencv"
|
||||||
|
)
|
||||||
|
indices = sample_indices_fn(metadata=metadata, **kwargs)
|
||||||
|
|
||||||
index = 0
|
index = 0
|
||||||
frames = []
|
frames = []
|
||||||
while video.isOpened():
|
while video.isOpened():
|
||||||
success, frame = video.read()
|
success, frame = video.read()
|
||||||
|
if not success:
|
||||||
|
break
|
||||||
if index in indices:
|
if index in indices:
|
||||||
height, width, channel = frame.shape
|
height, width, channel = frame.shape
|
||||||
|
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||||
frames.append(frame[0:height, 0:width, 0:channel])
|
frames.append(frame[0:height, 0:width, 0:channel])
|
||||||
if success:
|
if success:
|
||||||
index += 1
|
index += 1
|
||||||
@@ -604,70 +633,81 @@ def read_video_opencv(video_path: str, num_frames: Optional[int] = None, fps: Op
|
|||||||
break
|
break
|
||||||
|
|
||||||
video.release()
|
video.release()
|
||||||
return np.stack(frames)
|
metadata.frames_indices = indices
|
||||||
|
return np.stack(frames), metadata
|
||||||
|
|
||||||
|
|
||||||
def read_video_decord(video_path: str, num_frames: Optional[int] = None, fps: Optional[int] = None):
|
def read_video_decord(
|
||||||
|
video_path: str,
|
||||||
|
sample_indices_fn: Optional[Callable] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Decode the video with Decord decoder.
|
Decode a video using the Decord backend.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
video_path (`str`):
|
video_path (`str`):
|
||||||
Path to the video file.
|
Path to the video file.
|
||||||
num_frames (`int`, *optional*):
|
sample_indices_fn (`Callable`, *optional*):
|
||||||
Number of frames to sample uniformly. Should be passed only when `fps=None`.
|
A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
|
||||||
If not specified and `fps==None`, all frames are sampled.
|
by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
|
||||||
fps (`int`, *optional*):
|
If not provided, simple uniform sampling with fps is performed.
|
||||||
Number of frames to sample per second. Should be passed only when `num_frames=None`.
|
Example:
|
||||||
If not specified and `num_frames==None`, all frames are sampled.
|
def sample_indices_fn(metadata, **kwargs):
|
||||||
|
return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
np.ndarray: np array of decoded frames of shape (num_frames, height, width, 3).
|
Tuple[`np.array`, `VideoMetadata`]: A tuple containing:
|
||||||
|
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
|
||||||
|
- `VideoMetadata` object.
|
||||||
"""
|
"""
|
||||||
vr = VideoReader(uri=video_path, ctx=cpu(0)) # decord has problems with gpu
|
vr = VideoReader(uri=video_path, ctx=cpu(0)) # decord has problems with gpu
|
||||||
video_fps = vr.get_avg_fps()
|
video_fps = vr.get_avg_fps()
|
||||||
total_num_frames = len(vr)
|
total_num_frames = len(vr)
|
||||||
if num_frames is None and fps is not None:
|
duration = total_num_frames / video_fps if video_fps else 0
|
||||||
num_frames = int(total_num_frames / video_fps * fps)
|
metadata = VideoMetadata(
|
||||||
if num_frames > total_num_frames:
|
total_num_frames=int(total_num_frames), fps=float(video_fps), duration=float(duration), video_backend="decord"
|
||||||
raise ValueError(
|
)
|
||||||
f"When loading the video with fps={fps}, we identified that num_frames ({num_frames}) > total_frames ({total_num_frames}) ."
|
|
||||||
f"Make sure that fps of a video is less than the requested fps for loading. Detected video_fps={video_fps}"
|
indices = sample_indices_fn(metadata=metadata, **kwargs)
|
||||||
)
|
|
||||||
indices = get_uniform_frame_indices(total_num_frames=total_num_frames, num_frames=num_frames)
|
|
||||||
frames = vr.get_batch(indices).asnumpy()
|
frames = vr.get_batch(indices).asnumpy()
|
||||||
return frames
|
metadata.frames_indices = indices
|
||||||
|
return frames, metadata
|
||||||
|
|
||||||
|
|
||||||
def read_video_pyav(video_path: str, num_frames: Optional[int] = None, fps: Optional[int] = None):
|
def read_video_pyav(
|
||||||
|
video_path: str,
|
||||||
|
sample_indices_fn: Callable,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Decode the video with PyAV decoder.
|
Decode the video with PyAV decoder.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
video_path (`str`):
|
video_path (`str`):
|
||||||
Path to the video file.
|
Path to the video file.
|
||||||
num_frames (`int`, *optional*):
|
sample_indices_fn (`Callable`, *optional*):
|
||||||
Number of frames to sample uniformly. Should be passed only when `fps=None`.
|
A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
|
||||||
If not specified and `fps==None`, all frames are sampled.
|
by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
|
||||||
fps (`int`, *optional*):
|
If not provided, simple uniform sampling with fps is performed.
|
||||||
Number of frames to sample per second. Should be passed only when `num_frames=None`.
|
Example:
|
||||||
If not specified and `num_frames==None`, all frames are sampled.
|
def sample_indices_fn(metadata, **kwargs):
|
||||||
|
return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
np.ndarray: np array of decoded frames of shape (num_frames, height, width, 3).
|
Tuple[`np.array`, `VideoMetadata`]: A tuple containing:
|
||||||
|
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
|
||||||
|
- `VideoMetadata` object.
|
||||||
"""
|
"""
|
||||||
container = av.open(video_path)
|
container = av.open(video_path)
|
||||||
|
|
||||||
total_num_frames = container.streams.video[0].frames
|
total_num_frames = container.streams.video[0].frames
|
||||||
video_fps = container.streams.video[0].average_rate # should we better use `av_guess_frame_rate`?
|
video_fps = container.streams.video[0].average_rate # should we better use `av_guess_frame_rate`?
|
||||||
if num_frames is None and fps is not None:
|
duration = total_num_frames / video_fps if video_fps else 0
|
||||||
num_frames = int(total_num_frames / video_fps * fps)
|
metadata = VideoMetadata(
|
||||||
if num_frames > total_num_frames:
|
total_num_frames=int(total_num_frames), fps=float(video_fps), duration=float(duration), video_backend="pyav"
|
||||||
raise ValueError(
|
)
|
||||||
f"When loading the video with fps={fps}, we identified that num_frames ({num_frames}) > total_frames ({total_num_frames}) ."
|
indices = sample_indices_fn(metadata=metadata, **kwargs)
|
||||||
f"Make sure that fps of a video is less than the requested fps for loading. Detected video_fps={video_fps}"
|
|
||||||
)
|
|
||||||
indices = get_uniform_frame_indices(total_num_frames, num_frames=num_frames)
|
|
||||||
|
|
||||||
frames = []
|
frames = []
|
||||||
container.seek(0)
|
container.seek(0)
|
||||||
@@ -677,48 +717,58 @@ def read_video_pyav(video_path: str, num_frames: Optional[int] = None, fps: Opti
|
|||||||
break
|
break
|
||||||
if i >= 0 and i in indices:
|
if i >= 0 and i in indices:
|
||||||
frames.append(frame)
|
frames.append(frame)
|
||||||
return np.stack([x.to_ndarray(format="rgb24") for x in frames])
|
|
||||||
|
video = np.stack([x.to_ndarray(format="rgb24") for x in frames])
|
||||||
|
metadata.frames_indices = indices
|
||||||
|
return video, metadata
|
||||||
|
|
||||||
|
|
||||||
def read_video_torchvision(video_path: str, num_frames: Optional[int] = None, fps: Optional[int] = None):
|
def read_video_torchvision(
|
||||||
|
video_path: str,
|
||||||
|
sample_indices_fn: Callable,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Decode the video with torchvision decoder.
|
Decode the video with torchvision decoder.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
video_path (`str`):
|
video_path (`str`):
|
||||||
Path to the video file.
|
Path to the video file.
|
||||||
num_frames (`int`, *optional*):
|
sample_indices_fn (`Callable`, *optional*):
|
||||||
Number of frames to sample uniformly. Should be passed only when `fps=None`.
|
A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
|
||||||
If not specified and `fps==None`, all frames are sampled.
|
by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
|
||||||
fps (`int`, *optional*):
|
If not provided, simple uniform sampling with fps is performed.
|
||||||
Number of frames to sample per second. Should be passed only when `num_frames=None`.
|
Example:
|
||||||
If not specified and `num_frames==None`, all frames are sampled.
|
def sample_indices_fn(metadata, **kwargs):
|
||||||
|
return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
np.ndarray: np array of decoded frames of shape (num_frames, height, width, 3).
|
Tuple[`np.array`, `VideoMetadata`]: A tuple containing:
|
||||||
|
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
|
||||||
|
- `VideoMetadata` object.
|
||||||
"""
|
"""
|
||||||
video, _, info = torchvision_io.read_video(
|
video, _, info = torchvision_io.read_video(
|
||||||
video_path,
|
video_path,
|
||||||
start_pts=0.0,
|
start_pts=0.0,
|
||||||
end_pts=None,
|
end_pts=None,
|
||||||
pts_unit="sec",
|
pts_unit="sec",
|
||||||
output_format="TCHW",
|
output_format="THWC",
|
||||||
)
|
)
|
||||||
video_fps = info["video_fps"]
|
video_fps = info["video_fps"]
|
||||||
total_num_frames = video.size(0) - 1
|
total_num_frames = video.size(0)
|
||||||
if num_frames is None and fps is not None:
|
duration = total_num_frames / video_fps if video_fps else 0
|
||||||
num_frames = int(total_num_frames / video_fps * fps)
|
metadata = VideoMetadata(
|
||||||
if num_frames > total_num_frames:
|
total_num_frames=int(total_num_frames),
|
||||||
raise ValueError(
|
fps=float(video_fps),
|
||||||
f"When loading the video with fps={fps}, we identified that num_frames ({num_frames}) > total_frames ({total_num_frames}) ."
|
duration=float(duration),
|
||||||
f"Make sure that fps of a video is less than the requested fps for loading. Detected video_fps={video_fps}"
|
video_backend="torchvision",
|
||||||
)
|
)
|
||||||
|
|
||||||
if num_frames is not None:
|
indices = sample_indices_fn(metadata=metadata, **kwargs)
|
||||||
idx = torch.linspace(0, video.size(0) - 1, num_frames, dtype=torch.int64)
|
|
||||||
return video[idx]
|
|
||||||
|
|
||||||
return video
|
video = video[indices].contiguous().numpy()
|
||||||
|
metadata.frames_indices = indices
|
||||||
|
return video, metadata
|
||||||
|
|
||||||
|
|
||||||
VIDEO_DECODERS = {
|
VIDEO_DECODERS = {
|
||||||
@@ -734,6 +784,8 @@ def load_video(
|
|||||||
num_frames: Optional[int] = None,
|
num_frames: Optional[int] = None,
|
||||||
fps: Optional[int] = None,
|
fps: Optional[int] = None,
|
||||||
backend: str = "opencv",
|
backend: str = "opencv",
|
||||||
|
sample_indices_fn: Optional[Callable] = None,
|
||||||
|
**kwargs,
|
||||||
) -> np.array:
|
) -> np.array:
|
||||||
"""
|
"""
|
||||||
Loads `video` to a numpy array.
|
Loads `video` to a numpy array.
|
||||||
@@ -748,13 +800,36 @@ def load_video(
|
|||||||
If not specified and `num_frames==None`, all frames are sampled.
|
If not specified and `num_frames==None`, all frames are sampled.
|
||||||
backend (`str`, *optional*, defaults to `"opencv"`):
|
backend (`str`, *optional*, defaults to `"opencv"`):
|
||||||
The backend to use when loading the video. Can be any of ["decord", "pyav", "opencv", "torchvision"]. Defaults to "opencv".
|
The backend to use when loading the video. Can be any of ["decord", "pyav", "opencv", "torchvision"]. Defaults to "opencv".
|
||||||
|
sample_indices_fn (`Callable`, *optional*):
|
||||||
|
A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
|
||||||
|
by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
|
||||||
|
If not provided, simple uniformt sampling with fps is performed, otherwise `sample_indices_fn` has priority over other args.
|
||||||
|
The function expects at input the all args along with all kwargs passed to `load_video` and should output valid
|
||||||
|
indices at which the video should be sampled. For example:
|
||||||
|
|
||||||
|
Example:
|
||||||
|
def sample_indices_fn(metadata, **kwargs):
|
||||||
|
return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
`np.array`: A numpy array of shape (num_frames, channels, height, width).
|
Tuple[`np.array`, Dict]: A tuple containing:
|
||||||
|
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
|
||||||
|
- Metadata dictionary.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if fps is not None and num_frames is not None:
|
# If `sample_indices_fn` is given, we can accept any args as those might be needed by custom `sample_indices_fn`
|
||||||
raise ValueError("`num_frames` and `fps` are mutually exclusive arguments, please use only one!")
|
if fps is not None and num_frames is not None and sample_indices_fn is None:
|
||||||
|
raise ValueError(
|
||||||
|
"`num_frames`, `fps`, and `sample_indices_fn` are mutually exclusive arguments, please use only one!"
|
||||||
|
)
|
||||||
|
|
||||||
|
# If user didn't pass a sampling function, create one on the fly with default logic
|
||||||
|
if sample_indices_fn is None:
|
||||||
|
|
||||||
|
def sample_indices_fn_func(metadata, **fn_kwargs):
|
||||||
|
return default_sample_indices_fn(metadata, num_frames=num_frames, fps=fps, **fn_kwargs)
|
||||||
|
|
||||||
|
sample_indices_fn = sample_indices_fn_func
|
||||||
|
|
||||||
if video.startswith("https://www.youtube.com") or video.startswith("http://www.youtube.com"):
|
if video.startswith("https://www.youtube.com") or video.startswith("http://www.youtube.com"):
|
||||||
if not is_yt_dlp_available():
|
if not is_yt_dlp_available():
|
||||||
@@ -796,8 +871,8 @@ def load_video(
|
|||||||
)
|
)
|
||||||
|
|
||||||
video_decoder = VIDEO_DECODERS[backend]
|
video_decoder = VIDEO_DECODERS[backend]
|
||||||
video = video_decoder(file_obj, num_frames=num_frames, fps=fps)
|
video, metadata = video_decoder(file_obj, sample_indices_fn, **kwargs)
|
||||||
return video
|
return video, metadata
|
||||||
|
|
||||||
|
|
||||||
def load_images(
|
def load_images(
|
||||||
|
|||||||
@@ -24,13 +24,21 @@ import sys
|
|||||||
import typing
|
import typing
|
||||||
import warnings
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union
|
from typing import Any, Callable, Dict, List, Optional, Tuple, TypedDict, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import typing_extensions
|
import typing_extensions
|
||||||
|
|
||||||
from .dynamic_module_utils import custom_object_save
|
from .dynamic_module_utils import custom_object_save
|
||||||
from .image_utils import ChannelDimension, is_valid_image, is_vision_available, load_image, load_video
|
from .image_utils import (
|
||||||
|
ChannelDimension,
|
||||||
|
ImageInput,
|
||||||
|
VideoInput,
|
||||||
|
is_valid_image,
|
||||||
|
is_vision_available,
|
||||||
|
load_image,
|
||||||
|
load_video,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if is_vision_available():
|
if is_vision_available():
|
||||||
@@ -339,14 +347,10 @@ class ProcessingKwargs(TextKwargs, ImagesKwargs, VideosKwargs, AudioKwargs, Comm
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class ChatTemplateKwargs(TypedDict, total=False):
|
class TokenizerChatTemplateKwargs(TypedDict, total=False):
|
||||||
"""
|
"""
|
||||||
Keyword arguments for processor chat templates.
|
Keyword arguments for tokenizer's `apply_chat_template`, when it is called from within a processor.
|
||||||
|
|
||||||
tokenize (`bool`, *optional*, defaults to `False`):
|
|
||||||
Whether to tokenize the output or not.
|
|
||||||
return_dict (`bool`, defaults to `False`):
|
|
||||||
Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`.
|
|
||||||
tools (`List[Dict]`, *optional*):
|
tools (`List[Dict]`, *optional*):
|
||||||
A list of tools (callable functions) that will be accessible to the model. If the template does not
|
A list of tools (callable functions) that will be accessible to the model. If the template does not
|
||||||
support function calling, this argument will have no effect. Each tool should be passed as a JSON Schema,
|
support function calling, this argument will have no effect. Each tool should be passed as a JSON Schema,
|
||||||
@@ -373,6 +377,23 @@ class ChatTemplateKwargs(TypedDict, total=False):
|
|||||||
Whether to return a mask of the assistant generated tokens. For tokens generated by the assistant,
|
Whether to return a mask of the assistant generated tokens. For tokens generated by the assistant,
|
||||||
the mask will contain 1. For user and system tokens, the mask will contain 0.
|
the mask will contain 1. For user and system tokens, the mask will contain 0.
|
||||||
This functionality is only available for chat templates that support it via the `{% generation %}` keyword.
|
This functionality is only available for chat templates that support it via the `{% generation %}` keyword.
|
||||||
|
"""
|
||||||
|
|
||||||
|
tools: Optional[List[Dict]] = None
|
||||||
|
documents: Optional[List[Dict[str, str]]] = None
|
||||||
|
add_generation_prompt: Optional[bool] = False
|
||||||
|
continue_final_message: Optional[bool] = False
|
||||||
|
return_assistant_tokens_mask: Optional[bool] = False
|
||||||
|
|
||||||
|
|
||||||
|
class ProcessorChatTemplateKwargs(TokenizerChatTemplateKwargs, total=False):
|
||||||
|
"""
|
||||||
|
Keyword arguments for processor chat templates.
|
||||||
|
|
||||||
|
tokenize (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to tokenize the output or not.
|
||||||
|
return_dict (`bool`, defaults to `False`):
|
||||||
|
Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`.
|
||||||
num_frames (`int`, *optional*):
|
num_frames (`int`, *optional*):
|
||||||
Number of frames to sample uniformly. If not passed, the whole video is loaded.
|
Number of frames to sample uniformly. If not passed, the whole video is loaded.
|
||||||
video_load_backend (`str`, *optional*, defaults to `"pyav"`):
|
video_load_backend (`str`, *optional*, defaults to `"pyav"`):
|
||||||
@@ -382,22 +403,28 @@ class ChatTemplateKwargs(TypedDict, total=False):
|
|||||||
video_fps (`int`, *optional*):
|
video_fps (`int`, *optional*):
|
||||||
Number of frames to sample per second. Should be passed only when `num_frames=None`.
|
Number of frames to sample per second. Should be passed only when `num_frames=None`.
|
||||||
If not specified and `num_frames==None`, all frames are sampled.
|
If not specified and `num_frames==None`, all frames are sampled.
|
||||||
|
sample_indices_fn (`Callable`, *optional*):
|
||||||
|
A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
|
||||||
|
by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
|
||||||
|
If not provided, simple uniformt sampling with fps is performed, otherwise `sample_indices_fn` has priority over other args.
|
||||||
|
The function expects at input the all args along with all kwargs passed to `load_video` and should output valid
|
||||||
|
indices at which the video should be sampled. For example:
|
||||||
|
|
||||||
|
def sample_indices_fn(num_frames, fps, metadata, **kwargs):
|
||||||
|
# add you sampling logic here ...
|
||||||
|
return np.linspace(start_idx, end_idx, num_frames, dtype=int)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
tokenize: Optional[bool] = False
|
tokenize: Optional[bool] = False
|
||||||
return_dict: Optional[bool] = False
|
return_dict: Optional[bool] = False
|
||||||
tools: Optional[List[Dict]] = None
|
|
||||||
documents: Optional[List[Dict[str, str]]] = None
|
|
||||||
add_generation_prompt: Optional[bool] = False
|
|
||||||
continue_final_message: Optional[bool] = False
|
|
||||||
return_assistant_tokens_mask: Optional[bool] = False
|
|
||||||
num_frames: Optional[int] = None
|
num_frames: Optional[int] = None
|
||||||
video_load_backend: Optional[str] = "pyav"
|
video_load_backend: Optional[str] = "pyav"
|
||||||
video_fps: Optional[int] = None
|
video_fps: Optional[int] = None
|
||||||
|
sample_indices_fn: Optional[Callable] = None
|
||||||
|
|
||||||
|
|
||||||
class AllKwargsForChatTemplate(
|
class AllKwargsForChatTemplate(
|
||||||
TextKwargs, ImagesKwargs, VideosKwargs, AudioKwargs, CommonKwargs, ChatTemplateKwargs
|
TextKwargs, ImagesKwargs, VideosKwargs, AudioKwargs, CommonKwargs, ProcessorChatTemplateKwargs
|
||||||
): ...
|
): ...
|
||||||
|
|
||||||
|
|
||||||
@@ -1165,9 +1192,43 @@ class ProcessorMixin(PushToHubMixin):
|
|||||||
)
|
)
|
||||||
return {arg_name: arg_value for arg_value, arg_name in zip(args, self.optional_call_args)}
|
return {arg_name: arg_value for arg_value, arg_name in zip(args, self.optional_call_args)}
|
||||||
|
|
||||||
|
def _process_messages_for_chat_template(
|
||||||
|
self,
|
||||||
|
conversation: List[List[Dict[str, str]]],
|
||||||
|
batch_images: List[ImageInput],
|
||||||
|
batch_videos: List[VideoInput],
|
||||||
|
batch_video_metadata: List[List[Dict[str, any]]],
|
||||||
|
**chat_template_kwargs: Unpack[AllKwargsForChatTemplate],
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Used within `apply_chat_template` when a model has a special way to process conversation history. For example,
|
||||||
|
video models might want to specify in the prompt the duration of video or which frame indices at which timestamps
|
||||||
|
were sampled. This information cannot be accessed before the video is loaded.
|
||||||
|
|
||||||
|
For most models it is a no-op, and must be overriden by model processors which require special processing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
conversation (`List[Dict, str, str]`):
|
||||||
|
The conversation to process. Always comes in batched format.
|
||||||
|
batch_images (`List[List[ImageInput]]`):
|
||||||
|
Batch of images that were loaded from url/path defined in the conversation. The images
|
||||||
|
are ordered in the same way as in the conversation. Comes in nested list format, one list of `PIL` images
|
||||||
|
per batch.
|
||||||
|
batch_videos (`List[List[ImageInput]]`):
|
||||||
|
Batch of videos that were loaded from url/path defined in the conversation. The videos
|
||||||
|
are ordered in the samm way as in the conversation. Comes in nested list format, one list of 4D video arrays
|
||||||
|
per batch.
|
||||||
|
batch_video_metadata (`List[List[Dict[[str, any]]]]`):
|
||||||
|
Batch of metadata returned from loading videos. That includes video fps, duration and total number of framer in original video.
|
||||||
|
Metadata are ordered in the same way as `batch_videos`. Comes in nested list format, one list of 4D video arrays
|
||||||
|
per batch.
|
||||||
|
|
||||||
|
"""
|
||||||
|
return conversation
|
||||||
|
|
||||||
def apply_chat_template(
|
def apply_chat_template(
|
||||||
self,
|
self,
|
||||||
conversation: Union[List[Dict[str, str]]],
|
conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]]],
|
||||||
chat_template: Optional[str] = None,
|
chat_template: Optional[str] = None,
|
||||||
**kwargs: Unpack[AllKwargsForChatTemplate],
|
**kwargs: Unpack[AllKwargsForChatTemplate],
|
||||||
) -> str:
|
) -> str:
|
||||||
@@ -1190,7 +1251,7 @@ class ProcessorMixin(PushToHubMixin):
|
|||||||
]
|
]
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
conversation (`List[Dict, str, str]`):
|
conversation (`Union[List[Dict, [str, str]], List[List[Dict[str, str]]]]`):
|
||||||
The conversation to format.
|
The conversation to format.
|
||||||
chat_template (`Optional[str]`, *optional*):
|
chat_template (`Optional[str]`, *optional*):
|
||||||
The Jinja template to use for formatting the conversation. If not provided, the tokenizer's
|
The Jinja template to use for formatting the conversation. If not provided, the tokenizer's
|
||||||
@@ -1207,39 +1268,42 @@ class ProcessorMixin(PushToHubMixin):
|
|||||||
"https://huggingface.co/docs/transformers/main/en/chat_templating for more information."
|
"https://huggingface.co/docs/transformers/main/en/chat_templating for more information."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Fill two sets of kwargs that should be used by tokenizer's `apply_chat_template`
|
||||||
|
# and for multimodal chat template
|
||||||
|
tokenizer_template_kwargs = {}
|
||||||
|
for tokenizer_key in TokenizerChatTemplateKwargs.__annotations__.keys():
|
||||||
|
tokenizer_value = getattr(TokenizerChatTemplateKwargs, tokenizer_key, None)
|
||||||
|
value = kwargs.pop(tokenizer_key, tokenizer_value)
|
||||||
|
tokenizer_template_kwargs[tokenizer_key] = value
|
||||||
|
|
||||||
chat_template_kwargs = {}
|
chat_template_kwargs = {}
|
||||||
for key in ChatTemplateKwargs.__annotations__.keys():
|
for key in ProcessorChatTemplateKwargs.__annotations__.keys():
|
||||||
value = kwargs.pop(key, getattr(ChatTemplateKwargs, key))
|
processor_value = getattr(ProcessorChatTemplateKwargs, key, None)
|
||||||
|
value = kwargs.pop(key, processor_value)
|
||||||
chat_template_kwargs[key] = value
|
chat_template_kwargs[key] = value
|
||||||
|
|
||||||
# Pop kwargs that should not be used by tokenizer's `apply_chat_template`
|
|
||||||
tokenize = chat_template_kwargs.pop("tokenize")
|
|
||||||
return_dict = chat_template_kwargs.pop("return_dict")
|
|
||||||
num_frames = chat_template_kwargs.pop("num_frames")
|
|
||||||
video_fps = chat_template_kwargs.pop("video_fps")
|
|
||||||
video_load_backend = chat_template_kwargs.pop("video_load_backend")
|
|
||||||
|
|
||||||
prompt = self.tokenizer.apply_chat_template(
|
|
||||||
conversation,
|
|
||||||
chat_template=chat_template,
|
|
||||||
tokenize=False,
|
|
||||||
return_dict=False,
|
|
||||||
**chat_template_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(conversation, (list, tuple)) and (
|
if isinstance(conversation, (list, tuple)) and (
|
||||||
isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "content")
|
isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "content")
|
||||||
):
|
):
|
||||||
conversations = conversation
|
|
||||||
is_batched = True
|
is_batched = True
|
||||||
|
conversations = conversation
|
||||||
else:
|
else:
|
||||||
conversations = [conversation]
|
|
||||||
is_batched = False
|
is_batched = False
|
||||||
|
conversations = [conversation]
|
||||||
|
|
||||||
|
num_frames = chat_template_kwargs.get("num_frames")
|
||||||
|
video_fps = chat_template_kwargs.get("video_fps")
|
||||||
|
video_load_backend = chat_template_kwargs.get("video_load_backend")
|
||||||
|
tokenize = chat_template_kwargs.get("tokenize")
|
||||||
|
return_dict = chat_template_kwargs.get("return_dict")
|
||||||
|
sample_indices_fn = chat_template_kwargs.get("sample_indices_fn")
|
||||||
|
|
||||||
if tokenize:
|
if tokenize:
|
||||||
batch_images, batch_videos = [], []
|
batch_images, batch_videos = [], []
|
||||||
|
batch_video_metadata = []
|
||||||
for conversation in conversations:
|
for conversation in conversations:
|
||||||
images, videos = [], []
|
images, videos = [], []
|
||||||
|
video_metadata = []
|
||||||
for message in conversation:
|
for message in conversation:
|
||||||
visuals = [content for content in message["content"] if content["type"] in ["image", "video"]]
|
visuals = [content for content in message["content"] if content["type"] in ["image", "video"]]
|
||||||
image_fnames = [
|
image_fnames = [
|
||||||
@@ -1261,17 +1325,51 @@ class ProcessorMixin(PushToHubMixin):
|
|||||||
video = [np.array(load_image(image_fname)).T for image_fname in fname]
|
video = [np.array(load_image(image_fname)).T for image_fname in fname]
|
||||||
# create a 4D video because `load_video` always returns a 4D array
|
# create a 4D video because `load_video` always returns a 4D array
|
||||||
video = np.stack(video)
|
video = np.stack(video)
|
||||||
|
metadata = None
|
||||||
|
logger.warning(
|
||||||
|
"When loading the video from list of images, we cannot infer metadata such as `fps` or `duration`. "
|
||||||
|
"If you model applies special processing based on metadata, please load the whole video and let the model sample frames."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
video = load_video(fname, num_frames=num_frames, fps=video_fps, backend=video_load_backend)
|
video, metadata = load_video(
|
||||||
|
fname,
|
||||||
|
num_frames=num_frames,
|
||||||
|
fps=video_fps,
|
||||||
|
backend=video_load_backend,
|
||||||
|
sample_indices_fn=sample_indices_fn,
|
||||||
|
)
|
||||||
videos.append(video)
|
videos.append(video)
|
||||||
|
video_metadata.append(metadata)
|
||||||
|
|
||||||
# Currently all processors can accept accept nested list of batches, but not flat list of visuals
|
# Currently all processors can accept nested list of batches, but not flat list of visuals
|
||||||
# So we'll make a batched list of images and let the processor handle it
|
# So we'll make a batched list of images and let the processor handle it
|
||||||
if images:
|
if images:
|
||||||
batch_images.append(images)
|
batch_images.append(images)
|
||||||
if videos:
|
if videos:
|
||||||
batch_videos.append(videos)
|
batch_videos.append(videos)
|
||||||
|
batch_video_metadata.append(video_metadata)
|
||||||
|
|
||||||
|
# Process conversation with video/image information if needed. Then convert into a prompt using Jinja template
|
||||||
|
conversations = self._process_messages_for_chat_template(
|
||||||
|
conversations,
|
||||||
|
batch_images=batch_images,
|
||||||
|
batch_videos=batch_videos,
|
||||||
|
batch_video_metadata=batch_video_metadata,
|
||||||
|
**chat_template_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt = self.tokenizer.apply_chat_template(
|
||||||
|
conversations,
|
||||||
|
chat_template=chat_template,
|
||||||
|
tokenize=False,
|
||||||
|
return_dict=False,
|
||||||
|
**tokenizer_template_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not is_batched:
|
||||||
|
prompt = prompt[0]
|
||||||
|
|
||||||
|
if tokenize:
|
||||||
# Tokenizer's `apply_chat_template` never adds special tokens when tokenizing
|
# Tokenizer's `apply_chat_template` never adds special tokens when tokenizing
|
||||||
# But processor's `apply_chat_template` didn't have an option to tokenize, so users had to format the prompt
|
# But processor's `apply_chat_template` didn't have an option to tokenize, so users had to format the prompt
|
||||||
# and pass it to the processor. Users thus never worried about special tokens relying on processor hadnling
|
# and pass it to the processor. Users thus never worried about special tokens relying on processor hadnling
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ from pathlib import Path
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
|
||||||
from transformers.models.auto.processing_auto import processor_class_from_name
|
from transformers.models.auto.processing_auto import processor_class_from_name
|
||||||
from transformers.processing_utils import Unpack
|
from transformers.processing_utils import Unpack
|
||||||
@@ -538,7 +539,7 @@ class ProcessorTesterMixin:
|
|||||||
|
|
||||||
def test_chat_template_save_loading(self):
|
def test_chat_template_save_loading(self):
|
||||||
processor = self.get_processor()
|
processor = self.get_processor()
|
||||||
signature = inspect.signature(processor.__call__)
|
signature = inspect.signature(processor.__init__)
|
||||||
if "chat_template" not in {*signature.parameters.keys()}:
|
if "chat_template" not in {*signature.parameters.keys()}:
|
||||||
self.skipTest("Processor doesn't accept chat templates at input")
|
self.skipTest("Processor doesn't accept chat templates at input")
|
||||||
|
|
||||||
@@ -858,3 +859,119 @@ class ProcessorTesterMixin:
|
|||||||
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||||
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1)
|
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1)
|
||||||
self.assertEqual(len(out_dict_with_video[self.videos_input_name][0]), 2)
|
self.assertEqual(len(out_dict_with_video[self.videos_input_name][0]), 2)
|
||||||
|
|
||||||
|
@require_av
|
||||||
|
def test_chat_template_video_custom_sampling(self):
|
||||||
|
"""
|
||||||
|
Tests that models can pass their custom callables to sample video indices.
|
||||||
|
"""
|
||||||
|
processor = self.get_processor()
|
||||||
|
if processor.chat_template is None:
|
||||||
|
self.skipTest("Processor has no chat template")
|
||||||
|
|
||||||
|
signature = inspect.signature(processor.__call__)
|
||||||
|
if "videos" not in {*signature.parameters.keys()} or (
|
||||||
|
signature.parameters.get("videos") is not None
|
||||||
|
and signature.parameters["videos"].annotation == inspect._empty
|
||||||
|
):
|
||||||
|
self.skipTest("Processor doesn't accept videos at input")
|
||||||
|
|
||||||
|
video_file_path = hf_hub_download(
|
||||||
|
repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset"
|
||||||
|
)
|
||||||
|
messages = [
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "video",
|
||||||
|
"path": video_file_path,
|
||||||
|
},
|
||||||
|
{"type": "text", "text": "What is shown in this video?"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
]
|
||||||
|
|
||||||
|
def dummmy_sample_indices_fn(metadata, **fn_kwargs):
|
||||||
|
# sample only the first two frame always
|
||||||
|
return [0, 1]
|
||||||
|
|
||||||
|
out_dict_with_video = processor.apply_chat_template(
|
||||||
|
messages,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
tokenize=True,
|
||||||
|
return_dict=True,
|
||||||
|
sample_indices_fn=dummmy_sample_indices_fn,
|
||||||
|
)
|
||||||
|
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||||
|
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1)
|
||||||
|
self.assertEqual(len(out_dict_with_video[self.videos_input_name][0]), 2)
|
||||||
|
|
||||||
|
@require_av
|
||||||
|
def test_chat_template_video_special_processing(self):
|
||||||
|
"""
|
||||||
|
Tests that models can use their own preprocessing to preprocess conversations.
|
||||||
|
"""
|
||||||
|
processor = self.get_processor()
|
||||||
|
if processor.chat_template is None:
|
||||||
|
self.skipTest("Processor has no chat template")
|
||||||
|
|
||||||
|
signature = inspect.signature(processor.__call__)
|
||||||
|
if "videos" not in {*signature.parameters.keys()} or (
|
||||||
|
signature.parameters.get("videos") is not None
|
||||||
|
and signature.parameters["videos"].annotation == inspect._empty
|
||||||
|
):
|
||||||
|
self.skipTest("Processor doesn't accept videos at input")
|
||||||
|
|
||||||
|
video_file_path = hf_hub_download(
|
||||||
|
repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset"
|
||||||
|
)
|
||||||
|
messages = [
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "video", "path": video_file_path},
|
||||||
|
{"type": "text", "text": "What is shown in this video?"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
]
|
||||||
|
|
||||||
|
def _process_messages_for_chat_template(
|
||||||
|
conversation,
|
||||||
|
batch_images,
|
||||||
|
batch_videos,
|
||||||
|
batch_video_metadata,
|
||||||
|
**chat_template_kwargs,
|
||||||
|
):
|
||||||
|
# Let us just always return a dummy prompt
|
||||||
|
new_msg = [
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "video"}, # no need to use path, video is loaded already by this moment
|
||||||
|
{"type": "text", "text": "Dummy prompt for preprocess testing"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
]
|
||||||
|
return new_msg
|
||||||
|
|
||||||
|
processor._process_messages_for_chat_template = _process_messages_for_chat_template
|
||||||
|
out_dict_with_video = processor.apply_chat_template(
|
||||||
|
messages,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
tokenize=True,
|
||||||
|
return_dict=True,
|
||||||
|
)
|
||||||
|
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||||
|
|
||||||
|
# Check with `in` because we don't know how each template formats the prompt with BOS/EOS/etc
|
||||||
|
formatted_text = processor.batch_decode(out_dict_with_video["input_ids"], skip_special_tokens=True)[0]
|
||||||
|
self.assertTrue("Dummy prompt for preprocess testing" in formatted_text)
|
||||||
|
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1)
|
||||||
|
self.assertEqual(len(out_dict_with_video[self.videos_input_name][0]), 243)
|
||||||
|
|||||||
Reference in New Issue
Block a user