Chat template: return vectorized output in processors (#34275)

* update chat template

* style

* fix tests

* Update src/transformers/image_utils.py

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* typehints + docs

* fix tests

* remove unnecessary warnings

* forgot code style :(

* allow users to pass backend and num frames

* Update docs/source/en/chat_templating.md

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* Update src/transformers/image_utils.py

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* Update src/transformers/image_utils.py

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* Update src/transformers/image_utils.py

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* Update src/transformers/image_utils.py

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* Update src/transformers/image_utils.py

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* Update src/transformers/image_utils.py

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* Update src/transformers/processing_utils.py

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* typo fix

* style

* address comments

* align with "pipeline" template

* update docs

* update docs

* unpack for all kwargs?

* wrong conflict resolution while rebasing

* tmp

* update docs

* Update docs/source/en/chat_templating.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/chat_templating.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/chat_templating.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/chat_templating.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

---------

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
This commit is contained in:
Raushan Turganbay
2025-01-10 11:05:29 +01:00
committed by GitHub
parent 5f087d1335
commit e0646f3dce
12 changed files with 880 additions and 46 deletions

132
benchmark.py Normal file
View File

@@ -0,0 +1,132 @@
import os
import time
import cv2
import av
import numpy as np
from numba import jit, cuda
from decord import VideoReader, cpu, gpu
import torch
from torchvision import io
video_dir = "/raid/raushan/temp_dir/"
NUM_FRAMES = 32
# @jit(nopython=True, target_backend='cuda') # <-- If you have a cuda GPU
def process_video_cv2(video: cv2.VideoCapture, indices: np.array, length: int):
index = 0
frames = []
while video.isOpened():
success, frame = video.read()
if index in indices:
# Channel 0:B 1:G 2:R
height, width, channel = frame.shape
frames.append(frame[0:height, 0:width, 0:channel])
if success:
index += 1
if index >= length:
break
video.release()
return frames
def read_video_opencv(video_path, num_frames=NUM_FRAMES):
'''
Decode the video with open-cv decoder.
Args:
video_path (str): Path to the video file.
num_frames (int): Number of frames to sample uniformly. Defaults to NUM_FRAMES
Returns:
np.ndarray: np array of decoded frames of shape (num_frames, height, width, 3).
'''
video = cv2.VideoCapture(video_path)
fps = int(video.get(cv2.CAP_PROP_FPS))
total_num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
indices = np.arange(0, total_num_frames, total_num_frames / num_frames).astype(int)
frames = process_video_cv2(video, indices, total_num_frames)
return np.stack(frames)
def read_video_decord(video_path, num_frames=NUM_FRAMES):
'''
Decode the video with Decord decoder.
Args:
video_path (str): Path to the video file.
num_frames (int): Number of frames to sample uniformly. Defaults to NUM_FRAMES
Returns:
np.ndarray: np array of decoded frames of shape (num_frames, height, width, 3).
'''
vr = VideoReader(uri=video_path, ctx=cpu(0)) # you need to install from source to use gpu ctx
indices = np.arange(0, len(vr), len(vr) / num_frames).astype(int)
frames = vr.get_batch(indices).asnumpy()
return frames
def read_video_pyav(video_path, num_frames=NUM_FRAMES):
'''
Decode the video with PyAV decoder.
Args:
video_path (str): Path to the video file.
num_frames (int): Number of frames to sample uniformly. Defaults to NUM_FRAMES
Returns:
np.ndarray: np array of decoded frames of shape (num_frames, height, width, 3).
'''
container = av.open(video_path)
# sample uniformly "num_frames" frames from the video
total_frames = container.streams.video[0].frames
indices = np.arange(0, total_frames, total_frames / num_frames).astype(int)
frames = []
container.seek(0)
start_index = indices[0]
end_index = indices[-1]
for i, frame in enumerate(container.decode(video=0)):
if i > end_index:
break
if i >= start_index and i in indices:
frames.append(frame)
return np.stack([x.to_ndarray(format="rgb24") for x in frames])
def read_video_torchvision(video_path, num_frames=NUM_FRAMES):
video, _, info = io.read_video(
video_path,
start_pts=0.0,
end_pts=None,
pts_unit="sec",
output_format="TCHW",
)
idx = torch.linspace(0, video.size(0) - 1, num_frames, dtype=torch.int64)
return video[idx]
decoders = {"decord": read_video_decord, "opencv": read_video_opencv, "av": read_video_pyav, "torchvision": read_video_torchvision}
for name, fn in decoders.items():
start = time.perf_counter()
for video_file in os.listdir(video_dir):
path = f"{video_dir}/{video_file}"
output = fn(path)
end = time.perf_counter()
print(f"Time taken for {name}: {(end-start):.04f} sec")
# Time taken for decord: 475.2979 sec
# Time taken for opencv: 614.6062 sec
# Time taken for av: 1067.0860 sec
# Time taken for torchvision: 1924.0433 sec

View File

@@ -23,7 +23,7 @@ of text (as is the case with a standard language model), the model instead conti
of one or more **messages**, each of which includes a **role**, like "user" or "assistant", as well as message text. of one or more **messages**, each of which includes a **role**, like "user" or "assistant", as well as message text.
Much like tokenization, different models expect very different input formats for chat. This is the reason we added Much like tokenization, different models expect very different input formats for chat. This is the reason we added
**chat templates** as a feature. Chat templates are part of the tokenizer. They specify how to convert conversations, **chat templates** as a feature. Chat templates are part of the tokenizer for text-only LLMs or processor for multimodal LLMs. They specify how to convert conversations,
represented as lists of messages, into a single tokenizable string in the format that the model expects. represented as lists of messages, into a single tokenizable string in the format that the model expects.
Let's make this concrete with a quick example using the `mistralai/Mistral-7B-Instruct-v0.1` model: Let's make this concrete with a quick example using the `mistralai/Mistral-7B-Instruct-v0.1` model:
@@ -66,10 +66,12 @@ for you, allowing you to write universal code that works for any model.
## How do I use chat templates? ## How do I use chat templates?
As you can see in the example above, chat templates are easy to use. Simply build a list of messages, with `role` As you can see in the example above, chat templates are easy to use. Simply build a list of messages, with `role`
and `content` keys, and then pass it to the [`~PreTrainedTokenizer.apply_chat_template`] method. Once you do that, and `content` keys, and then pass it to the [`~PreTrainedTokenizer.apply_chat_template`] or [`~ProcessorMixin.apply_chat_template`] method
depending on what type of model you are using. Once you do that,
you'll get output that's ready to go! When using chat templates as input for model generation, it's also a good idea you'll get output that's ready to go! When using chat templates as input for model generation, it's also a good idea
to use `add_generation_prompt=True` to add a [generation prompt](#what-are-generation-prompts). to use `add_generation_prompt=True` to add a [generation prompt](#what-are-generation-prompts).
## Usage with text-only LLMs
Here's an example of preparing input for `model.generate()`, using `Zephyr` again: Here's an example of preparing input for `model.generate()`, using `Zephyr` again:
```python ```python
@@ -116,6 +118,44 @@ How many helicopters can a human eat in one sitting?</s>
Matey, I'm afraid I must inform ye that humans cannot eat helicopters. Helicopters are not food, they are flying machines. Food is meant to be eaten, like a hearty plate o' grog, a savory bowl o' stew, or a delicious loaf o' bread. But helicopters, they be for transportin' and movin' around, not for eatin'. So, I'd say none, me hearties. None at all. Matey, I'm afraid I must inform ye that humans cannot eat helicopters. Helicopters are not food, they are flying machines. Food is meant to be eaten, like a hearty plate o' grog, a savory bowl o' stew, or a delicious loaf o' bread. But helicopters, they be for transportin' and movin' around, not for eatin'. So, I'd say none, me hearties. None at all.
``` ```
## Usage with multimodal LLMs
For multimodal LLMs such as [LLaVA](https://huggingface.co/llava-hf) the prompts can be formatted in a similar way. The only difference is you need to pass input images/videos as well along with the text. Each `"content"`
has to be a list containing either a text or an image/video.
Here's an example of preparing input for using `LLaVA` model:
```python
from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration
model_id = "llava-hf/llava-onevision-qwen2-0.5b-ov-hf"
model = LlavaOnevisionForConditionalGeneration.from_pretrained(model_id) # You may want to use bfloat16 and/or move to GPU here
processor = AutoProcessor.from_pretrained(model_id)
messages = [
{
"role": "system",
"content": [{"type": "text", "text": "You are a friendly chatbot who always responds in the style of a pirate"}],
},
{
"role": "user",
"content": [
{"type": "image", "url": "http://images.cocodataset.org/val2017/000000039769.jpg"},
{"type": "text", "text": "What are these?"},
],
},
]
processed_chat = processor.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt")
print(processor.batch_decode(processed_chat["input_ids"][:, :30]))
```
This yields a string in LLaVAs expected input format with many `<image>` tokens at the end.
The `<image>` tokens are placeholders and each one will be replaced by image embeddings when the mode is run in the forward call. The `processed_chat` can be further passed into [`~GenerationMixin.generate`] to generate text.
```text
'<|im_start|>system
You are a friendly chatbot who always responds in the style of a pirate<|im_end|><|im_start|>user <image><image><image><image><image><image><image><image>'
```
Arr, 'twas easy after all! Arr, 'twas easy after all!
## Is there an automated pipeline for chat? ## Is there an automated pipeline for chat?

77
read_video.py Normal file
View File

@@ -0,0 +1,77 @@
import numpy as np
import cv2
import requests
from yt_dlp import YoutubeDL
from contextlib import redirect_stdout
from pathlib import Path
import io
import imageio.v3 as iio
url = "https://test-videos.co.uk/vids/bigbuckbunny/mp4/h264/720/Big_Buck_Bunny_720_10s_10MB.mp4"
vid = cv2.VideoCapture(url)
# ret, frame = vid.read()
while(True):
# Capture frame-by-frame
ret, frame = vid.read()
#print cap.isOpened(), ret
if frame is not None:
pass
# print(frame.shape)
else:
break
print(vid.isOpened(), frame is not None)
buffer = io.BytesIO(requests.get(url).content)
video = buffer.getvalue()
frames = iio.imread(video, index=None)
print(frames.shape)
youtube_id = "https://www.youtube.com/watch?v=BaW_jenozKc"
ctx = {
"outtmpl": "-",
'logtostderr': True
}
buffer = io.BytesIO()
with redirect_stdout(buffer), YoutubeDL(ctx) as foo:
foo.download([youtube_id])
# Path(f"vi.mp4").write_bytes(buffer.getvalue())
video = buffer.getvalue()
print(type(video))
frames = iio.imread(video, index=None)
print(frames.shape)
import decord
file_obj = io.BytesIO(video)
container = decord.VideoReader(file_obj)
print(container[2].shape)
# print(np.frombuffer(video, dtype=np.uint8).shape)
# img_array = np.asarray(bytearray(video), dtype=np.uint8)
# im = cv2.imdecode(img_array, cv2.IMREAD_UNCHANGED)
import av
file_obj = io.BytesIO(video)
container = av.open(file_obj)
container.seek(0)
frames = []
for i, frame in enumerate(container.decode(video=0)):
if i > 10:
break
if i >= 0:
frames.append(frame)
out = np.stack([x.to_ndarray(format="rgb24") for x in frames])
print(out.shape)

107
run.py Normal file
View File

@@ -0,0 +1,107 @@
import av
import torch
import decord
from decord import VideoReader, cpu
import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download
from transformers import LlavaNextVideoProcessor, LlavaNextVideoForConditionalGeneration, SiglipImageProcessor
model_id = "/raid/raushan/llava-next-video-qwen-7b"
model = LlavaNextVideoForConditionalGeneration.from_pretrained(
model_id,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
).to(0)
processor = LlavaNextVideoProcessor.from_pretrained(model_id, torch_dtype=torch.bfloat16)
img_proc = SiglipImageProcessor.from_pretrained("google/siglip-so400m-patch14-384")
image = Image.open("/raid/raushan/image.png")
def load_video(video_path, max_frames_num,fps=1,force_sample=False):
vr = VideoReader(video_path)
total_frame_num = len(vr)
video_time = total_frame_num / vr.get_avg_fps()
fps = round(vr.get_avg_fps()/fps)
frame_idx = [i for i in range(0, len(vr), fps)]
frame_time = [i/fps for i in frame_idx]
if len(frame_idx) > max_frames_num or force_sample:
sample_fps = max_frames_num
uniform_sampled_frames = np.linspace(0, total_frame_num - 1, sample_fps, dtype=int)
frame_idx = uniform_sampled_frames.tolist()
frame_time = [i/vr.get_avg_fps() for i in frame_idx]
frame_time = ",".join([f"{i:.2f}s" for i in frame_time])
spare_frames = vr.get_batch(frame_idx).asnumpy()
print(spare_frames.shape)
return spare_frames,frame_time,video_time
def read_video_pyav(container, indices):
'''
Decode the video with PyAV decoder.
Args:
container (`av.container.input.InputContainer`): PyAV container.
indices (`List[int]`): List of frame indices to decode.
Returns:
result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
'''
frames = []
container.seek(0)
start_index = indices[0]
end_index = indices[-1]
for i, frame in enumerate(container.decode(video=0)):
if i > end_index:
break
if i >= start_index and i in indices:
frames.append(frame)
return np.stack([x.to_ndarray(format="rgb24") for x in frames])
# define a chat history and use `apply_chat_template` to get correctly formatted prompt
# Each value in "content" has to be a list of dicts with types ("text", "image", "video")
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# <image>Time farmes are this moments and we ahev 64 frames
# Please describe this video in detail.<|im_end|>
# <|im_start|>assistant
conversation = [
{
"role": "system",
"content": [
{"type": "text", "text": "You are a helpful assistant."},
],
},
{
"role": "user",
"content": [
{"type": "text", "text": "The video lasts for 19.97 seconds, and 64 frames are uniformly sampled from it. These frames are located at 0.00s,0.30s,0.60s,0.93s,1.23s,1.57s,1.87s,2.20s,2.50s,2.83s,3.13s,3.47s,3.77s,4.10s,4.40s,4.73s,5.03s,5.37s,5.67s,6.00s,6.30s,6.63s,6.93s,7.27s,7.57s,7.90s,8.20s,8.53s,8.83s,9.17s,9.47s,9.80s,10.10s,10.43s,10.73s,11.07s,11.37s,11.70s,12.00s,12.33s,12.63s,12.97s,13.27s,13.60s,13.90s,14.23s,14.53s,14.87s,15.17s,15.50s,15.80s,16.13s,16.43s,16.77s,17.07s,17.40s,17.70s,18.03s,18.33s,18.67s,18.97s,19.30s,19.60s,19.93s.Please answer the following questions related to this video.\nPlease describe this video in detail."},
{"type": "video"},
],
},
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<video>The video lasts for 19.97 seconds, and 64 frames are uniformly sampled from it. These frames are located at 0.00s,0.30s,0.60s,0.93s,1.23s,1.57s,1.87s,2.20s,2.50s,2.83s,3.13s,3.47s,3.77s,4.10s,4.40s,4.73s,5.03s,5.37s,5.67s,6.00s,6.30s,6.63s,6.93s,7.27s,7.57s,7.90s,8.20s,8.53s,8.83s,9.17s,9.47s,9.80s,10.10s,10.43s,10.73s,11.07s,11.37s,11.70s,12.00s,12.33s,12.63s,12.97s,13.27s,13.60s,13.90s,14.23s,14.53s,14.87s,15.17s,15.50s,15.80s,16.13s,16.43s,16.77s,17.07s,17.40s,17.70s,18.03s,18.33s,18.67s,18.97s,19.30s,19.60s,19.93s.Please answer the following questions related to this video.\nPlease describe this video in detail.<|im_end|>\n<|im_start|>assistant"
video_path = "/raid/raushan/karate.mp4" # hf_hub_download(repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset")
container = av.open(video_path)
# sample uniformly 8 frames from the video, can sample more for longer videos
total_frames = container.streams.video[0].frames
indices = np.arange(0, total_frames, total_frames / 64).astype(int)
clip = read_video_pyav(container, indices)
clip, frame_time,video_time = load_video(video_path, max_frames_num=64, force_sample=True)
inputs_video = processor(text=prompt, videos=clip, return_tensors="pt").to(device=model.device, dtype=torch.bfloat16)
output = model.generate(**inputs_video, max_new_tokens=100, do_sample=False)
print(processor.decode(output[0][2:], skip_special_tokens=True))

View File

@@ -15,6 +15,7 @@
import base64 import base64
import os import os
from contextlib import redirect_stdout
from io import BytesIO from io import BytesIO
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
@@ -25,6 +26,9 @@ from packaging import version
from .utils import ( from .utils import (
ExplicitEnum, ExplicitEnum,
TensorType, TensorType,
is_av_available,
is_cv2_available,
is_decord_available,
is_jax_tensor, is_jax_tensor,
is_numpy_array, is_numpy_array,
is_tf_tensor, is_tf_tensor,
@@ -32,6 +36,7 @@ from .utils import (
is_torch_tensor, is_torch_tensor,
is_torchvision_available, is_torchvision_available,
is_vision_available, is_vision_available,
is_yt_dlp_available,
logging, logging,
requires_backends, requires_backends,
to_numpy, to_numpy,
@@ -56,6 +61,7 @@ if is_vision_available():
PILImageResampling = PIL.Image PILImageResampling = PIL.Image
if is_torchvision_available(): if is_torchvision_available():
from torchvision import io as torchvision_io
from torchvision.transforms import InterpolationMode from torchvision.transforms import InterpolationMode
pil_torch_interpolation_mapping = { pil_torch_interpolation_mapping = {
@@ -67,6 +73,17 @@ if is_vision_available():
PILImageResampling.LANCZOS: InterpolationMode.LANCZOS, PILImageResampling.LANCZOS: InterpolationMode.LANCZOS,
} }
if is_decord_available():
from decord import VideoReader, cpu
if is_av_available():
import av
if is_cv2_available():
import cv2
if is_yt_dlp_available():
from yt_dlp import YoutubeDL
if TYPE_CHECKING: if TYPE_CHECKING:
if is_torch_available(): if is_torch_available():
@@ -386,6 +403,204 @@ def load_image(image: Union[str, "PIL.Image.Image"], timeout: Optional[float] =
return image return image
def get_uniform_frame_indices(total_num_frames: int, num_frames: Optional[int] = None):
"""
Creates a numpy array for uniform sampling of `num_frame` frames from `total_num_frames`
when loading a video.
Args:
total_num_frames (`int`):
Total number of frames that a video has.
num_frames (`int`, *optional*):
Number of frames to sample uniformly. If not specified, all frames are sampled.
Returns:
np.ndarray: np array of frame indices that will be sampled.
"""
if num_frames is not None:
indices = np.arange(0, total_num_frames, total_num_frames / num_frames).astype(int)
else:
indices = np.arange(0, total_num_frames).astype(int)
return indices
def read_video_opencv(video_path: str, num_frames: Optional[int] = None):
"""
Decode the video with open-cv decoder.
Args:
video_path (`str`):
Path to the video file.
num_frames (`int`, *optional*):
Number of frames to sample uniformly. If not specified, all frames are sampled.
Returns:
np.ndarray: np array of decoded frames of shape (num_frames, height, width, 3).
"""
video = cv2.VideoCapture(video_path)
total_num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
indices = get_uniform_frame_indices(total_num_frames, num_frames=num_frames)
index = 0
frames = []
while video.isOpened():
success, frame = video.read()
if index in indices:
height, width, channel = frame.shape
frames.append(frame[0:height, 0:width, 0:channel])
if success:
index += 1
if index >= total_num_frames:
break
video.release()
return np.stack(frames)
def read_video_decord(video_path: str, num_frames: Optional[int] = None):
"""
Decode the video with Decord decoder.
Args:
video_path (`str`):
Path to the video file.
num_frames (`int`, *optional*):
Number of frames to sample uniformly. If not specified, all frames are sampled.
Returns:
np.ndarray: np array of decoded frames of shape (num_frames, height, width, 3).
"""
vr = VideoReader(uri=video_path, ctx=cpu(0)) # decord has problems with gpu
indices = get_uniform_frame_indices(total_num_frames=len(vr), num_frames=num_frames)
frames = vr.get_batch(indices).asnumpy()
return frames
def read_video_pyav(video_path: str, num_frames: Optional[int] = None):
"""
Decode the video with PyAV decoder.
Args:
video_path (`str`):
Path to the video file.
num_frames (`int`, *optional*):
Number of frames to sample uniformly. If not specified, all frames are sampled.
Returns:
np.ndarray: np array of decoded frames of shape (num_frames, height, width, 3).
"""
container = av.open(video_path)
# sample uniformly "num_frames" frames from the video
total_num_frames = container.streams.video[0].frames
indices = get_uniform_frame_indices(total_num_frames, num_frames=num_frames)
frames = []
container.seek(0)
end_index = indices[-1]
for i, frame in enumerate(container.decode(video=0)):
if i > end_index:
break
if i >= 0 and i in indices:
frames.append(frame)
return np.stack([x.to_ndarray(format="rgb24") for x in frames])
def read_video_torchvision(video_path: str, num_frames: Optional[int] = None):
"""
Decode the video with torchvision decoder.
Args:
video_path (`str`):
Path to the video file.
num_frames (`int`, *optional*):
Number of frames to sample uniformly. If not specified, all frames are sampled.
Returns:
np.ndarray: np array of decoded frames of shape (num_frames, height, width, 3).
"""
video, _, info = torchvision_io.read_video(
video_path,
start_pts=0.0,
end_pts=None,
pts_unit="sec",
output_format="TCHW",
)
if num_frames is not None:
idx = torch.linspace(0, video.size(0) - 1, num_frames, dtype=torch.int64)
return video[idx]
return video
VIDEO_DECODERS = {
"decord": read_video_decord,
"opencv": read_video_opencv,
"pyav": read_video_pyav,
"torchvision": read_video_torchvision,
}
def load_video(video: Union[str, "VideoInput"], num_frames: Optional[int] = None, backend: str = "opencv") -> np.array:
"""
Loads `video` to a numpy array.
Args:
video (`str` or `VideoInput`):
The video to convert to the numpy array format. Can be a link to video or local path.
num_frames (`int`, *optional*):
Number of frames to sample uniformly. If not passed, the whole video is loaded.
backend (`str`, *optional*, defaults to `"opencv"`):
The backend to use when loading the video. Can be any of ["decord", "pyav", "opencv", "torchvision"]. Defaults to "opencv".
Returns:
`np.array`: A numpy array of shape (num_frames, channels, height, width).
"""
if video.startswith("https://www.youtube.com") or video.startswith("http://www.youtube.com"):
if not is_yt_dlp_available():
raise ImportError("To load a video from YouTube url you have to install `yt_dlp` first.")
buffer = BytesIO()
with redirect_stdout(buffer), YoutubeDL() as f:
f.download([video])
bytes_obj = buffer.getvalue()
file_obj = BytesIO(bytes_obj)
elif video.startswith("http://") or video.startswith("https://"):
file_obj = BytesIO(requests.get(video).content)
elif os.path.isfile(video):
file_obj = video
elif is_valid_image(video) or (isinstance(video, (list, tuple) and is_valid_image(video[0]))):
file_obj = None
else:
raise TypeError("Incorrect format used for video. Should be an url linking to an video or a local path.")
# can also load with decord, but not cv2/torchvision
# both will fail in case of url links
video_is_url = video.startswith("http://") or video.startswith("https://")
if video_is_url and backend in ["opencv", "torchvision"]:
raise ValueError(
"If you are trying to load a video from URL, you can decode the video only with `pyav` or `decord` as backend"
)
if file_obj is None:
return video
if (
(not is_decord_available() and backend == "decord")
or (not is_av_available() and backend == "pyav")
or (not is_cv2_available() and backend == "opencv")
or (not is_torchvision_available() and backend == "torchvision")
):
raise ImportError(
f"You chose backend={backend} for loading the video but the required library is not found in your environment "
f"Make sure to install {backend} before loading the video."
)
video_decoder = VIDEO_DECODERS[backend]
video = video_decoder(file_obj)
return video
def load_images( def load_images(
images: Union[List, Tuple, str, "PIL.Image.Image"], timeout: Optional[float] = None images: Union[List, Tuple, str, "PIL.Image.Image"], timeout: Optional[float] = None
) -> Union["PIL.Image.Image", List["PIL.Image.Image"], List[List["PIL.Image.Image"]]]: ) -> Union["PIL.Image.Image", List["PIL.Image.Image"], List[List["PIL.Image.Image"]]]:

View File

@@ -164,7 +164,7 @@ class LlavaOnevisionProcessor(ProcessorMixin):
if videos is not None: if videos is not None:
video_inputs = self.video_processor(videos, **output_kwargs["videos_kwargs"]) video_inputs = self.video_processor(videos, **output_kwargs["videos_kwargs"])
one_video = to_numpy_array(video_inputs["pixel_values_videos"][0]) one_video = to_numpy_array(video_inputs.get("pixel_values_videos")[0])
height, width = get_image_size(one_video[0], channel_dim=output_kwargs["images_kwargs"].get("data_format")) height, width = get_image_size(one_video[0], channel_dim=output_kwargs["images_kwargs"].get("data_format"))
num_frames = one_video.shape[0] # frame dim is always after batch dim num_frames = one_video.shape[0] # frame dim is always after batch dim
patches_height_width = int(math.sqrt(self.num_image_tokens)) patches_height_width = int(math.sqrt(self.num_image_tokens))

View File

@@ -30,7 +30,7 @@ import numpy as np
import typing_extensions import typing_extensions
from .dynamic_module_utils import custom_object_save from .dynamic_module_utils import custom_object_save
from .image_utils import ChannelDimension, is_valid_image, is_vision_available from .image_utils import ChannelDimension, is_valid_image, is_vision_available, load_image, load_video
if is_vision_available(): if is_vision_available():
@@ -336,6 +336,64 @@ class ProcessingKwargs(TextKwargs, ImagesKwargs, VideosKwargs, AudioKwargs, Comm
} }
class ChatTemplateKwargs(TypedDict, total=False):
"""
Keyword arguments for processor chat templates.
tokenize (`bool`, *optional*, defaults to `False`):
Whether to tokenize the output or not.
return_dict (`bool`, defaults to `False`):
Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`.
tools (`List[Dict]`, *optional*):
A list of tools (callable functions) that will be accessible to the model. If the template does not
support function calling, this argument will have no effect. Each tool should be passed as a JSON Schema,
giving the name, description and argument types for the tool. See our
[chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use)
for more information.
documents (`List[Dict[str, str]]`, *optional*):
A list of dicts representing documents that will be accessible to the model if it is performing RAG
(retrieval-augmented generation). If the template does not support RAG, this argument will have no
effect. We recommend that each document should be a dict containing "title" and "text" keys. Please
see the RAG section of the [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#arguments-for-RAG)
for examples of passing documents with chat templates.
add_generation_prompt (bool, *optional*):
If this is set, a prompt with the token(s) that indicate
the start of an assistant message will be appended to the formatted output. This is useful when you want to generate a response from the model.
Note that this argument will be passed to the chat template, and so it must be supported in the
template for this argument to have any effect.
continue_final_message (bool, *optional*):
If this is set, the chat will be formatted so that the final
message in the chat is open-ended, without any EOS tokens. The model will continue this message
rather than starting a new one. This allows you to "prefill" part of
the model's response for it. Cannot be used at the same time as `add_generation_prompt`.
return_assistant_tokens_mask (`bool`, defaults to `False`):
Whether to return a mask of the assistant generated tokens. For tokens generated by the assistant,
the mask will contain 1. For user and system tokens, the mask will contain 0.
This functionality is only available for chat templates that support it via the `{% generation %}` keyword.
num_frames (`int`, *optional*):
Number of frames to sample uniformly. If not passed, the whole video is loaded.
video_load_backend (`str`, *optional*, defaults to `"pyav"`):
The backend to use when loading the video which will be used only when there are videos in the conversation.
Can be any of ["decord", "pyav", "opencv", "torchvision"]. Defaults to "pyav" because it is the only backend
that supports all types of sources to load from.
"""
tokenize: Optional[bool] = False
return_dict: Optional[bool] = False
tools: Optional[List[Dict]] = None
documents: Optional[List[Dict[str, str]]] = None
add_generation_prompt: Optional[bool] = False
continue_final_message: Optional[bool] = False
return_assistant_tokens_mask: Optional[bool] = False
num_frames: Optional[int] = None
video_load_backend: Optional[str] = "pyav"
class AllKwargsForChatTemplate(
TextKwargs, ImagesKwargs, VideosKwargs, AudioKwargs, CommonKwargs, ChatTemplateKwargs
): ...
class ProcessorMixin(PushToHubMixin): class ProcessorMixin(PushToHubMixin):
""" """
This is a mixin used to provide saving/loading functionality for all processor classes. This is a mixin used to provide saving/loading functionality for all processor classes.
@@ -1100,23 +1158,32 @@ class ProcessorMixin(PushToHubMixin):
self, self,
conversation: Union[List[Dict[str, str]]], conversation: Union[List[Dict[str, str]]],
chat_template: Optional[str] = None, chat_template: Optional[str] = None,
tokenize: bool = False, **kwargs: Unpack[AllKwargsForChatTemplate],
**kwargs,
) -> str: ) -> str:
""" """
Similar to the `apply_chat_template` method on tokenizers, this method applies a Jinja template to input Similar to the `apply_chat_template` method on tokenizers, this method applies a Jinja template to input
conversations to turn them into a single tokenizable string. conversations to turn them into a single tokenizable string.
The input is expected to be in the following format, where each message content is a list consisting of text and
optionally image or video inputs. One can also provide an image, video, URL or local path which will be used to form
`pixel_values` when `return_dict=True`. If not provided, one will get only the formatted text, optionally tokenized text.
conversation = [
{
"role": "user",
"content": [
{"type": "image", "image": "https://www.ilankelman.org/stopsigns/australia.jpg"},
{"type": "text", "text": "Please describe this image in detail."},
],
},
]
Args: Args:
conversation (`List[Dict, str, str]`): conversation (`List[Dict, str, str]`):
The conversation to format. The conversation to format.
chat_template (`Optional[str]`, *optional*): chat_template (`Optional[str]`, *optional*):
The Jinja template to use for formatting the conversation. If not provided, the tokenizer's The Jinja template to use for formatting the conversation. If not provided, the tokenizer's
chat template is used. chat template is used.
tokenize (`bool`, *optional*, defaults to `False`):
Whether to tokenize the output or not.
**kwargs:
Additional keyword arguments
""" """
if chat_template is None: if chat_template is None:
@@ -1128,10 +1195,62 @@ class ProcessorMixin(PushToHubMixin):
"or provide a chat template as an argument. See " "or provide a chat template as an argument. See "
"https://huggingface.co/docs/transformers/main/en/chat_templating for more information." "https://huggingface.co/docs/transformers/main/en/chat_templating for more information."
) )
return self.tokenizer.apply_chat_template(
conversation, chat_template=chat_template, tokenize=tokenize, **kwargs text_kwargs = {}
for key in TextKwargs.__annotations__.keys():
value = kwargs.pop(key, None)
if value is not None:
text_kwargs[key] = value
chat_template_kwargs = {}
for key in ChatTemplateKwargs.__annotations__.keys():
value = kwargs.pop(key, getattr(ChatTemplateKwargs, key))
chat_template_kwargs[key] = value
# Pop kwargs that should not be used by tokenizer's `apply_chat_template`
tokenize = chat_template_kwargs.pop("tokenize")
return_dict = chat_template_kwargs.pop("return_dict")
num_frames = chat_template_kwargs.pop("num_frames")
video_load_backend = chat_template_kwargs.pop("video_load_backend")
prompt = self.tokenizer.apply_chat_template(
conversation,
chat_template=chat_template,
tokenize=False,
return_dict=False,
**text_kwargs,
**chat_template_kwargs,
) )
# we will have to return all processed inputs in a dict
if tokenize:
images, videos = [], []
for message in conversation:
visuals = [content for content in message["content"] if content["type"] in ["image", "video"]]
for vision_info in visuals:
if vision_info["type"] == "image":
for key in ["image", "url", "path", "base64"]:
if key in vision_info:
images.append(load_image(vision_info[key]))
elif vision_info["type"] == "video":
for key in ["video", "url", "path"]:
if key in vision_info:
videos.append(
load_video(vision_info[key], num_frames=num_frames, backend=video_load_backend)
)
out = self(
text=prompt,
images=images if images else None,
videos=videos if videos else None,
**kwargs,
)
if return_dict:
return out
else:
return out["input_ids"]
return prompt
def post_process_image_text_to_text(self, generated_outputs): def post_process_image_text_to_text(self, generated_outputs):
""" """
Post-process the output of a vlm to decode the text. Post-process the output of a vlm to decode the text.

View File

@@ -131,6 +131,7 @@ from .import_utils import (
is_cv2_available, is_cv2_available,
is_cython_available, is_cython_available,
is_datasets_available, is_datasets_available,
is_decord_available,
is_detectron2_available, is_detectron2_available,
is_eetq_available, is_eetq_available,
is_essentia_available, is_essentia_available,
@@ -236,6 +237,7 @@ from .import_utils import (
is_uroman_available, is_uroman_available,
is_vision_available, is_vision_available,
is_vptq_available, is_vptq_available,
is_yt_dlp_available,
requires_backends, requires_backends,
torch_only_method, torch_only_method,
) )

View File

@@ -101,6 +101,7 @@ _apex_available = _is_package_available("apex")
_aqlm_available = _is_package_available("aqlm") _aqlm_available = _is_package_available("aqlm")
_vptq_available, _vptq_version = _is_package_available("vptq", return_version=True) _vptq_available, _vptq_version = _is_package_available("vptq", return_version=True)
_av_available = importlib.util.find_spec("av") is not None _av_available = importlib.util.find_spec("av") is not None
_decord_available = importlib.util.find_spec("decord") is not None
_bitsandbytes_available = _is_package_available("bitsandbytes") _bitsandbytes_available = _is_package_available("bitsandbytes")
_eetq_available = _is_package_available("eetq") _eetq_available = _is_package_available("eetq")
_fbgemm_gpu_available = _is_package_available("fbgemm_gpu") _fbgemm_gpu_available = _is_package_available("fbgemm_gpu")
@@ -113,6 +114,7 @@ _bs4_available = importlib.util.find_spec("bs4") is not None
_coloredlogs_available = _is_package_available("coloredlogs") _coloredlogs_available = _is_package_available("coloredlogs")
# `importlib.metadata.util` doesn't work with `opencv-python-headless`. # `importlib.metadata.util` doesn't work with `opencv-python-headless`.
_cv2_available = importlib.util.find_spec("cv2") is not None _cv2_available = importlib.util.find_spec("cv2") is not None
_yt_dlp_available = importlib.util.find_spec("yt_dlp") is not None
_datasets_available = _is_package_available("datasets") _datasets_available = _is_package_available("datasets")
_detectron2_available = _is_package_available("detectron2") _detectron2_available = _is_package_available("detectron2")
# We need to check both `faiss` and `faiss-cpu`. # We need to check both `faiss` and `faiss-cpu`.
@@ -313,6 +315,10 @@ def is_cv2_available():
return _cv2_available return _cv2_available
def is_yt_dlp_available():
return _yt_dlp_available
def is_torch_available(): def is_torch_available():
return _torch_available return _torch_available
@@ -841,6 +847,10 @@ def is_av_available():
return _av_available return _av_available
def is_decord_available():
return _decord_available
def is_ninja_available(): def is_ninja_available():
r""" r"""
Code comes from *torch.utils.cpp_extension.is_ninja_available()*. Returns `True` if the Code comes from *torch.utils.cpp_extension.is_ninja_available()*. Returns `True` if the
@@ -1276,6 +1286,22 @@ pip install av
Please note that you may need to restart your runtime after installation. Please note that you may need to restart your runtime after installation.
""" """
# docstyle-ignore
YT_DLP_IMPORT_ERROR = """
{0} requires the YT-DLP library but it was not found in your environment. You can install it with:
```
pip install yt-dlp
```
Please note that you may need to restart your runtime after installation.
"""
DECORD_IMPORT_ERROR = """
{0} requires the PyAv library but it was not found in your environment. You can install it with:
```
pip install decord
```
Please note that you may need to restart your runtime after installation.
"""
# docstyle-ignore # docstyle-ignore
CV2_IMPORT_ERROR = """ CV2_IMPORT_ERROR = """
@@ -1616,6 +1642,7 @@ BACKENDS_MAPPING = OrderedDict(
("bs4", (is_bs4_available, BS4_IMPORT_ERROR)), ("bs4", (is_bs4_available, BS4_IMPORT_ERROR)),
("cv2", (is_cv2_available, CV2_IMPORT_ERROR)), ("cv2", (is_cv2_available, CV2_IMPORT_ERROR)),
("datasets", (is_datasets_available, DATASETS_IMPORT_ERROR)), ("datasets", (is_datasets_available, DATASETS_IMPORT_ERROR)),
("decord", (is_decord_available, DECORD_IMPORT_ERROR)),
("detectron2", (is_detectron2_available, DETECTRON2_IMPORT_ERROR)), ("detectron2", (is_detectron2_available, DETECTRON2_IMPORT_ERROR)),
("essentia", (is_essentia_available, ESSENTIA_IMPORT_ERROR)), ("essentia", (is_essentia_available, ESSENTIA_IMPORT_ERROR)),
("faiss", (is_faiss_available, FAISS_IMPORT_ERROR)), ("faiss", (is_faiss_available, FAISS_IMPORT_ERROR)),
@@ -1654,6 +1681,7 @@ BACKENDS_MAPPING = OrderedDict(
("jieba", (is_jieba_available, JIEBA_IMPORT_ERROR)), ("jieba", (is_jieba_available, JIEBA_IMPORT_ERROR)),
("peft", (is_peft_available, PEFT_IMPORT_ERROR)), ("peft", (is_peft_available, PEFT_IMPORT_ERROR)),
("jinja", (is_jinja_available, JINJA_IMPORT_ERROR)), ("jinja", (is_jinja_available, JINJA_IMPORT_ERROR)),
("yt_dlp", (is_yt_dlp_available, YT_DLP_IMPORT_ERROR)),
] ]
) )

View File

@@ -17,8 +17,8 @@ import tempfile
import unittest import unittest
from transformers import AutoProcessor, AutoTokenizer, LlamaTokenizerFast, LlavaProcessor from transformers import AutoProcessor, AutoTokenizer, LlamaTokenizerFast, LlavaProcessor
from transformers.testing_utils import require_vision from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_vision_available from transformers.utils import is_torch_available, is_vision_available
from ...test_processing_common import ProcessorTesterMixin from ...test_processing_common import ProcessorTesterMixin
@@ -26,6 +26,9 @@ from ...test_processing_common import ProcessorTesterMixin
if is_vision_available(): if is_vision_available():
from transformers import CLIPImageProcessor from transformers import CLIPImageProcessor
if is_torch_available:
import torch
@require_vision @require_vision
class LlavaProcessorTest(ProcessorTesterMixin, unittest.TestCase): class LlavaProcessorTest(ProcessorTesterMixin, unittest.TestCase):
@@ -94,6 +97,55 @@ class LlavaProcessorTest(ProcessorTesterMixin, unittest.TestCase):
formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True) formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
self.assertEqual(expected_prompt, formatted_prompt) self.assertEqual(expected_prompt, formatted_prompt)
def test_chat_template_dict(self):
processor = LlavaProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
messages = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": "What is shown in this image?"},
],
},
]
formatted_prompt_tokenized = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
expected_output = [[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 338, 4318, 297, 445, 1967, 29973, 319, 1799, 9047, 13566, 29901]] # fmt: skip
self.assertListEqual(expected_output, formatted_prompt_tokenized)
out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])
# add image URL for return dict
messages[0]["content"][0] = {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}
out_dict_with_image = processor.apply_chat_template(
messages, add_generation_prompt=True, tokenize=True, return_dict=True
)
self.assertListEqual(list(out_dict_with_image.keys()), ["input_ids", "attention_mask", "pixel_values"])
@require_torch
def test_chat_template_dict_torch(self):
processor = LlavaProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
messages = [
{
"role": "user",
"content": [
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
{"type": "text", "text": "What is shown in this image?"},
],
},
]
out_dict_tensors = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
)
self.assertListEqual(list(out_dict_tensors.keys()), ["input_ids", "attention_mask", "pixel_values"])
self.assertTrue(isinstance(out_dict_tensors["input_ids"], torch.Tensor))
def test_chat_template_with_continue_final_message(self): def test_chat_template_with_continue_final_message(self):
processor = LlavaProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf") processor = LlavaProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
expected_prompt = "USER: <image>\nDescribe this image. ASSISTANT: There is a dog and" expected_prompt = "USER: <image>\nDescribe this image. ASSISTANT: There is a dog and"

View File

@@ -16,8 +16,8 @@ import shutil
import tempfile import tempfile
import unittest import unittest
from transformers.testing_utils import require_vision from transformers.testing_utils import require_av, require_torch, require_vision
from transformers.utils import is_vision_available from transformers.utils import is_torch_available, is_vision_available
from ...test_processing_common import ProcessorTesterMixin from ...test_processing_common import ProcessorTesterMixin
@@ -31,6 +31,9 @@ if is_vision_available():
Qwen2TokenizerFast, Qwen2TokenizerFast,
) )
if is_torch_available:
import torch
@require_vision @require_vision
class LlavaOnevisionProcessorTest(ProcessorTesterMixin, unittest.TestCase): class LlavaOnevisionProcessorTest(ProcessorTesterMixin, unittest.TestCase):
@@ -100,3 +103,60 @@ class LlavaOnevisionProcessorTest(ProcessorTesterMixin, unittest.TestCase):
formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True) formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
self.assertEqual(expected_prompt, formatted_prompt) self.assertEqual(expected_prompt, formatted_prompt)
@require_av
def test_chat_template_dict(self):
processor = AutoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf")
messages = [
{
"role": "user",
"content": [
{"type": "video"},
{"type": "text", "text": "What is shown in this video?"},
],
},
]
formatted_prompt_tokenized = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
expected_output = [[151644, 872, 220, 151647, 198, 3838, 374, 6839, 304, 419, 2766, 30, 151645, 151644, 77091, 198]] # fmt: skip
self.assertListEqual(expected_output, formatted_prompt_tokenized)
out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])
# add image URL for return dict
messages[0]["content"][0] = {
"type": "video",
"url": "https://test-videos.co.uk/vids/bigbuckbunny/mp4/h264/720/Big_Buck_Bunny_720_10s_10MB.mp4",
}
out_dict_with_video = processor.apply_chat_template(
messages, add_generation_prompt=True, tokenize=True, return_dict=True
)
self.assertListEqual(list(out_dict_with_video.keys()), ["input_ids", "attention_mask", "pixel_values_videos"])
@require_torch
@require_av
def test_chat_template_dict_torch(self):
processor = AutoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf")
messages = [
{
"role": "user",
"content": [
{
"type": "video",
"url": "https://test-videos.co.uk/vids/bigbuckbunny/mp4/h264/720/Big_Buck_Bunny_720_10s_10MB.mp4",
},
{"type": "text", "text": "What is shown in this video?"},
],
},
]
out_dict_tensors = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
)
self.assertListEqual(list(out_dict_tensors.keys()), ["input_ids", "attention_mask", "pixel_values_videos"])
self.assertTrue(isinstance(out_dict_tensors["input_ids"], torch.Tensor))

View File

@@ -110,32 +110,34 @@ class MllamaProcessorTest(ProcessorTesterMixin, unittest.TestCase):
] ]
input_ids = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True) input_ids = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
expected_ids = [ expected_ids = [
128000, # <|begin_of_text|> [
128006, # <|start_header_id|> 128000, # <|begin_of_text|>
9125, # "system" 128006, # <|start_header_id|>
128007, # <|end_of_header|> 9125, # "system"
271, # "\n\n" 128007, # <|end_of_header|>
2028, 271, # "\n\n"
374, 2028,
264, 374,
1296, 264,
11914, 1296,
13, # "This is a test sentence." 11914,
128009, # <|eot_id|> 13, # "This is a test sentence."
128006, # <|start_header_id|> 128009, # <|eot_id|>
882, # "user" 128006, # <|start_header_id|>
128007, # <|end_of_header|> 882, # "user"
271, # "\n\n" 128007, # <|end_of_header|>
2028, 271, # "\n\n"
374, 2028,
264, 374,
2077, 264,
13, # "This is a response.", 2077,
128009, # <|eot_id|> 13, # "This is a response.",
128006, # <|start_header_id|> 128009, # <|eot_id|>
78191, # "assistant" 128006, # <|start_header_id|>
128007, # <|end_of_header|> 78191, # "assistant"
271, # "\n\n" 128007, # <|end_of_header|>
271, # "\n\n"
]
] ]
self.assertEqual(input_ids, expected_ids) self.assertEqual(input_ids, expected_ids)
@@ -146,9 +148,9 @@ class MllamaProcessorTest(ProcessorTesterMixin, unittest.TestCase):
"role": "user", "role": "user",
"content": [ "content": [
{"type": "text", "text": "Describe this image in two sentences"}, {"type": "text", "text": "Describe this image in two sentences"},
{"type": "image"}, {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
{"type": "text", "text": " Test sentence "}, {"type": "text", "text": " Test sentence "},
{"type": "image"}, {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
{"type": "text", "text": "ok\n"}, {"type": "text", "text": "ok\n"},
], ],
} }
@@ -164,10 +166,10 @@ class MllamaProcessorTest(ProcessorTesterMixin, unittest.TestCase):
input_ids = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True) input_ids = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
# fmt: off # fmt: off
expected_ids = [ expected_ids = [[
128000, 128006, 882, 128007, 271, 75885, 420, 2217, 304, 1403, 23719, 128256, 128000, 128006, 882, 128007, 271, 75885, 420, 2217, 304, 1403, 23719, 128256,
3475, 11914, 262, 128256, 564, 198, 128009, 128006, 78191, 128007, 271, 3475, 11914, 262, 128256, 564, 198, 128009, 128006, 78191, 128007, 271,
] ]]
# fmt: on # fmt: on
self.assertEqual(input_ids, expected_ids) self.assertEqual(input_ids, expected_ids)