Pixtral: vectorize patch embeddings and enable tests (#35122)

* initial POC

* - batch mix feature

* fix tests

* fix tests

* make style

* do not skip and instead fix tests

* update

* return back the test

* correct text with the correct ckpt
This commit is contained in:
Raushan Turganbay
2025-01-30 12:40:18 +01:00
committed by GitHub
parent 8bc4c89ee9
commit 9725e5be2f
10 changed files with 422 additions and 545 deletions

View File

@@ -280,6 +280,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
pixel_values: torch.FloatTensor, pixel_values: torch.FloatTensor,
vision_feature_layer: Union[int, List[int]], vision_feature_layer: Union[int, List[int]],
vision_feature_select_strategy: str, vision_feature_select_strategy: str,
**kwargs,
): ):
""" """
Obtains image last hidden states from the vision tower and apply multimodal projection. Obtains image last hidden states from the vision tower and apply multimodal projection.
@@ -300,8 +301,9 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
if vision_feature_select_strategy not in ["default", "full"]: if vision_feature_select_strategy not in ["default", "full"]:
raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}") raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")
kwargs = {k: v for k, v in kwargs.items() if v is not None}
# this is not memory efficient at all (output_hidden_states=True) will save all the hidden states. # this is not memory efficient at all (output_hidden_states=True) will save all the hidden states.
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True) image_outputs = self.vision_tower(pixel_values, output_hidden_states=True, **kwargs)
# If we have one vision feature layer, return the corresponding hidden states, # If we have one vision feature layer, return the corresponding hidden states,
# otherwise, select the hidden states of each feature layer and concatenate them # otherwise, select the hidden states of each feature layer and concatenate them
@@ -422,6 +424,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0, logits_to_keep: Union[int, torch.Tensor] = 0,
image_sizes: torch.Tensor = None,
) -> Union[Tuple, LlavaCausalLMOutputWithPast]: ) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
r""" r"""
Args: Args:
@@ -492,6 +495,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
pixel_values=pixel_values, pixel_values=pixel_values,
vision_feature_layer=vision_feature_layer, vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy, vision_feature_select_strategy=vision_feature_select_strategy,
image_sizes=image_sizes,
) )
n_image_tokens = (input_ids == self.config.image_token_index).sum().item() n_image_tokens = (input_ids == self.config.image_token_index).sum().item()

View File

@@ -15,12 +15,13 @@
"""Image processor class for Pixtral.""" """Image processor class for Pixtral."""
import math import math
from typing import Any, Callable, Dict, List, Optional, Tuple, Union from typing import Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import ( from ...image_transforms import (
pad,
resize, resize,
to_channel_dimension_format, to_channel_dimension_format,
) )
@@ -31,13 +32,13 @@ from ...image_utils import (
get_image_size, get_image_size,
infer_channel_dimension_format, infer_channel_dimension_format,
is_scaled_image, is_scaled_image,
is_valid_image, make_list_of_images,
to_numpy_array, to_numpy_array,
valid_images, valid_images,
validate_kwargs, validate_kwargs,
validate_preprocess_arguments, validate_preprocess_arguments,
) )
from ...utils import TensorType, is_torch_device, is_torch_dtype, is_vision_available, logging from ...utils import TensorType, is_vision_available, logging
from ...utils.import_utils import requires_backends from ...utils.import_utils import requires_backends
@@ -48,91 +49,6 @@ if is_vision_available():
import PIL import PIL
class BatchMixFeature(BatchFeature):
    """`BatchFeature` variant whose `to` method also walks through (nested) lists of tensors."""

    def to(self, *args, **kwargs) -> "BatchMixFeature":
        """
        Move and/or cast every tensor stored in `self.data` by calling `tensor.to(...)` (PyTorch only).

        Floating point tensors receive the full `to(*args, **kwargs)` call, so they can be cast and moved.
        Other tensors are only moved to the requested device (if any) and never cast, which avoids turning
        integer tensors such as token ids into floats. Values that are neither tensors nor lists are left as is.

        Args:
            args (`Tuple`):
                Forwarded to the `to(...)` function of the tensors. When no `device` keyword is given, the
                first positional argument is inspected: a dtype is accepted as is, while a string, a torch
                device or an int is taken as the target device.
            kwargs (`Dict`, *optional*):
                Forwarded to the `to(...)` function of the tensors. A `device` entry, if present, is the
                target device.

        Returns:
            [`BatchMixFeature`]: The same instance after modification.

        Raises:
            ValueError: If the first positional argument is neither a dtype nor something usable as a device.
        """
        requires_backends(self, ["torch"])
        import torch  # noqa

        target_device = kwargs.get("device")
        # Without an explicit `device=` keyword, the first positional argument decides: dtype or device.
        if target_device is None and len(args) > 0:
            arg = args[0]
            if not is_torch_dtype(arg):
                if isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int):
                    target_device = arg
                else:
                    raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")

        def _move(value):
            # Lists may be nested, so recurse until tensors (or other leaves) are reached.
            if isinstance(value, list):
                return [_move(item) for item in value]
            if not isinstance(value, torch.Tensor):
                return value
            # Only floating point tensors get the full call (cast + move).
            if torch.is_floating_point(value):
                return value.to(*args, **kwargs)
            # Non-float tensors are moved but never cast.
            if target_device is not None:
                return value.to(device=target_device)
            return value

        self.data = {key: _move(value) for key, value in self.data.items()}
        return self
# Copied from transformers.models.idefics2.image_processing_idefics2.make_list_of_images
def make_list_of_images(images: ImageInput) -> List[List[np.ndarray]]:
    """
    Normalize the accepted image inputs into a list of batches, i.e. a list of lists of images.

    Args:
        images (`ImageInput`):
            A single image, a list/tuple of images, or a list/tuple of non-empty lists/tuples of images.

    Returns:
        A nested list: a single image becomes `[[image]]`, a flat list of images becomes `[images]`, and an
        input that is already a list of batches is returned unchanged.

    Raises:
        ValueError: If `images` matches none of the accepted layouts (this includes empty lists).
    """
    # A lone image forms one batch containing one image.
    if is_valid_image(images):
        return [[images]]

    if isinstance(images, (list, tuple)) and len(images) > 0:
        first = images[0]
        # A flat sequence of images is treated as a single batch.
        if is_valid_image(first):
            return [images]
        # A sequence of non-empty sequences of images is already in the target layout.
        if isinstance(first, (list, tuple)) and len(first) > 0 and is_valid_image(first[0]):
            return images

    raise ValueError(
        "Invalid input type. Must be a single image, a list of images, or a list of batches of images."
    )
# Adapted from function in image_transforms.py to ensure any transparent pixels are converted to white. # Adapted from function in image_transforms.py to ensure any transparent pixels are converted to white.
def convert_to_rgb(image: ImageInput) -> ImageInput: def convert_to_rgb(image: ImageInput) -> ImageInput:
""" """
@@ -219,18 +135,6 @@ def get_resize_output_image_size(
return num_height_tokens * patch_height, num_width_tokens * patch_width return num_height_tokens * patch_height, num_width_tokens * patch_width
# Hack to get tensor conversion used in BatchFeature without batching the images
def _get_is_as_tensor_fns(tensor_type: Union[str, TensorType]) -> Tuple[Callable, Callable]:
    """Return the pair of callables that a throwaway, empty `BatchFeature` provides for `tensor_type`."""
    scratch_feature = BatchFeature()
    return scratch_feature._get_is_as_tensor_fns(tensor_type)
def convert_to_tensor(array, tensor_type: Union[str, TensorType]) -> Any:
    """
    Convert `array` to the framework selected by `tensor_type`, unless it already is such a tensor.

    Args:
        array: The value to convert.
        tensor_type (`str` or `TensorType`): The target tensor framework.

    Returns:
        `array` itself when the framework's `is_tensor` check accepts it, otherwise the result of the
        framework's `as_tensor` conversion.
    """
    is_tensor, as_tensor = _get_is_as_tensor_fns(tensor_type)
    return array if is_tensor(array) else as_tensor(array)
class PixtralImageProcessor(BaseImageProcessor): class PixtralImageProcessor(BaseImageProcessor):
r""" r"""
Constructs a Pixtral image processor. Constructs a Pixtral image processor.
@@ -368,6 +272,49 @@ class PixtralImageProcessor(BaseImageProcessor):
**kwargs, **kwargs,
) )
def _pad_for_batching(
    self,
    pixel_values: List[np.ndarray],
    image_sizes: List[List[int]],
    data_format: Optional[Union[str, ChannelDimension]] = None,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> List[np.ndarray]:
    """
    Pads every image along its height and width so that all images share the largest height and the largest
    width found in `image_sizes`. Padding is only added after the existing rows and columns.

    Args:
        pixel_values (`List[np.ndarray]`):
            One array per image. Images may have different heights and widths.
        image_sizes (`List[List[int]]`):
            A list of sizes for each image in `pixel_values` in (height, width) format.
        data_format (`str` or `ChannelDimension`, *optional*):
            The channel dimension format for the output image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            If unset, will use same as the input image.
        input_data_format (`str` or `ChannelDimension`, *optional*):
            The channel dimension format for the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            If unset, will use the inferred format of the input image.

    Returns:
        List[`np.ndarray`]: The padded images, in the same order as `pixel_values`.

    Raises:
        ValueError: If `pixel_values` and `image_sizes` differ in length, or if they are empty.
    """
    # `zip` would silently drop the trailing images/sizes on a length mismatch, so fail loudly instead.
    if len(pixel_values) != len(image_sizes):
        raise ValueError(
            f"Got {len(pixel_values)} images but {len(image_sizes)} image sizes; they must match one to one."
        )
    if len(image_sizes) == 0:
        raise ValueError("Cannot pad an empty list of images.")

    max_height = max(size[0] for size in image_sizes)
    max_width = max(size[1] for size in image_sizes)

    # The fill value is whatever `pad` uses by default (presumably zeros - confirm against `pad`).
    return [
        pad(
            image,
            padding=((0, max_height - size[0]), (0, max_width - size[1])),
            data_format=data_format,
            input_data_format=input_data_format,
        )
        for image, size in zip(pixel_values, image_sizes)
    ]
def preprocess( def preprocess(
self, self,
images: ImageInput, images: ImageInput,
@@ -449,9 +396,9 @@ class PixtralImageProcessor(BaseImageProcessor):
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys) validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)
images_list = make_list_of_images(images) images = make_list_of_images(images)
if not valid_images(images_list[0]): if not valid_images(images[0]):
raise ValueError( raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray." "torch.Tensor, tf.Tensor or jax.ndarray."
@@ -469,12 +416,12 @@ class PixtralImageProcessor(BaseImageProcessor):
) )
if do_convert_rgb: if do_convert_rgb:
images_list = [[convert_to_rgb(image) for image in images] for images in images_list] images = [convert_to_rgb(image) for image in images]
# All transformations expect numpy arrays. # All transformations expect numpy arrays.
images_list = [[to_numpy_array(image) for image in images] for images in images_list] images = [to_numpy_array(image) for image in images]
if do_rescale and is_scaled_image(images_list[0][0]): if do_rescale and is_scaled_image(images[0]):
logger.warning_once( logger.warning_once(
"It looks like you are trying to rescale already rescaled images. If the input" "It looks like you are trying to rescale already rescaled images. If the input"
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
@@ -482,44 +429,43 @@ class PixtralImageProcessor(BaseImageProcessor):
if input_data_format is None: if input_data_format is None:
# We assume that all images have the same channel dimension format. # We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images_list[0][0]) input_data_format = infer_channel_dimension_format(images[0])
batch_images = [] batch_images = []
batch_image_sizes = [] batch_image_sizes = []
for sample_images in images_list: for image in images:
images = [] if do_resize:
image_sizes = [] image = self.resize(
for image in sample_images: image=image,
if do_resize: size=size,
image = self.resize( patch_size=patch_size,
image=image, resample=resample,
size=size, input_data_format=input_data_format,
patch_size=patch_size, )
resample=resample,
input_data_format=input_data_format,
)
if do_rescale: if do_rescale:
image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
if do_normalize: if do_normalize:
image = self.normalize( image = self.normalize(
image=image, mean=image_mean, std=image_std, input_data_format=input_data_format image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
) )
images.append(image) image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
image_sizes.append(get_image_size(image, input_data_format))
batch_images.append(images)
batch_image_sizes.append(image_sizes)
images_list = [ batch_images.append(image)
[to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images] batch_image_sizes.append(get_image_size(image, data_format))
for images in batch_images
]
# Convert to tensor type outside of BatchFeature to avoid batching the images of different sizes pixel_values = self._pad_for_batching(
images_list = [[convert_to_tensor(image, return_tensors) for image in images] for images in images_list] pixel_values=batch_images,
return BatchMixFeature(data={"pixel_values": images_list, "image_sizes": batch_image_sizes}, tensor_type=None) image_sizes=batch_image_sizes,
input_data_format=data_format,
data_format=data_format,
)
return BatchFeature(
data={"pixel_values": pixel_values, "image_sizes": batch_image_sizes}, tensor_type=return_tensors
)
__all__ = ["PixtralImageProcessor"] __all__ = ["PixtralImageProcessor"]

View File

@@ -16,7 +16,7 @@
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
from ...image_processing_utils import get_size_dict from ...image_processing_utils import BatchFeature, get_size_dict
from ...image_processing_utils_fast import BaseImageProcessorFast from ...image_processing_utils_fast import BaseImageProcessorFast
from ...image_utils import ( from ...image_utils import (
ChannelDimension, ChannelDimension,
@@ -26,6 +26,7 @@ from ...image_utils import (
get_image_size, get_image_size,
get_image_type, get_image_type,
infer_channel_dimension_format, infer_channel_dimension_format,
make_list_of_images,
validate_fast_preprocess_arguments, validate_fast_preprocess_arguments,
validate_kwargs, validate_kwargs,
) )
@@ -38,10 +39,8 @@ from ...utils import (
logging, logging,
) )
from .image_processing_pixtral import ( from .image_processing_pixtral import (
BatchMixFeature,
convert_to_rgb, convert_to_rgb,
get_resize_output_image_size, get_resize_output_image_size,
make_list_of_images,
) )
@@ -189,6 +188,36 @@ class PixtralImageProcessorFast(BaseImageProcessorFast):
**kwargs, **kwargs,
) )
# Adapted from transformers.models.pixtral.image_processing_pixtral.PixtralImageProcessor._pad_for_batching
def _pad_for_batching(
self,
pixel_values: List[torch.Tensor],
image_sizes: List[List[int]],
):
"""
Pads images on the `num_of_patches` dimension with zeros to form a batch of same number of patches.
Args:
pixel_values (`List[torch.Tensor]`):
An array of pixel values of each images of shape (`batch_size`, `channels`, `height`, `width`)
image_sizes (`List[List[int]]`):
A list of sizes for each image in `pixel_values` in (height, width) format.
Returns:
List[`torch.Tensor`]: The padded images.
"""
max_shape = (
max([size[0] for size in image_sizes]),
max([size[1] for size in image_sizes]),
)
pixel_values = [
torch.nn.functional.pad(
image,
pad=(0, max_shape[1] - size[1], 0, max_shape[0] - size[0]),
)
for image, size in zip(pixel_values, image_sizes)
]
return torch.stack(pixel_values)
def preprocess( def preprocess(
self, self,
images: ImageInput, images: ImageInput,
@@ -206,7 +235,7 @@ class PixtralImageProcessorFast(BaseImageProcessorFast):
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None, input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs, **kwargs,
) -> BatchMixFeature: ) -> BatchFeature:
""" """
Preprocess an image or batch of images. Preprocess an image or batch of images.
@@ -271,8 +300,8 @@ class PixtralImageProcessorFast(BaseImageProcessorFast):
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys) validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)
images_list = make_list_of_images(images) images = make_list_of_images(images)
image_type = get_image_type(images_list[0][0]) image_type = get_image_type(images[0])
if image_type not in [ImageType.PIL, ImageType.TORCH, ImageType.NUMPY]: if image_type not in [ImageType.PIL, ImageType.TORCH, ImageType.NUMPY]:
raise ValueError(f"Unsupported input image type {image_type}") raise ValueError(f"Unsupported input image type {image_type}")
@@ -290,65 +319,63 @@ class PixtralImageProcessorFast(BaseImageProcessorFast):
data_format=data_format, data_format=data_format,
) )
if do_convert_rgb:
images_list = [[convert_to_rgb(image) for image in images] for images in images_list]
if image_type == ImageType.PIL:
images_list = [[F.pil_to_tensor(image) for image in images] for images in images_list]
elif image_type == ImageType.NUMPY:
# not using F.to_tensor as it doesn't handle (C, H, W) numpy arrays
images_list = [[torch.from_numpy(image).contiguous() for image in images] for images in images_list]
if device is not None:
images_list = [[image.to(device) for image in images] for images in images_list]
# We assume that all images have the same channel dimension format.
if input_data_format is None:
input_data_format = infer_channel_dimension_format(images_list[0][0])
if input_data_format == ChannelDimension.LAST:
images_list = [[image.permute(2, 0, 1).contiguous() for image in images] for images in images_list]
input_data_format = ChannelDimension.FIRST
if do_rescale and do_normalize: if do_rescale and do_normalize:
# fused rescale and normalize # fused rescale and normalize
new_mean = torch.tensor(image_mean, device=images_list[0][0].device) * (1.0 / rescale_factor) new_mean = torch.tensor(image_mean, device=device) * (1.0 / rescale_factor)
new_std = torch.tensor(image_std, device=images_list[0][0].device) * (1.0 / rescale_factor) new_std = torch.tensor(image_std, device=device) * (1.0 / rescale_factor)
batch_images = [] batch_images = []
batch_image_sizes = [] batch_image_sizes = []
for sample_images in images_list: for image in images:
images = [] if do_convert_rgb:
image_sizes = [] image = convert_to_rgb(image)
for image in sample_images:
if do_resize:
interpolation = (
pil_torch_interpolation_mapping[resample]
if isinstance(resample, (PILImageResampling, int))
else resample
)
image = self.resize(
image=image,
size=size,
patch_size=patch_size,
interpolation=interpolation,
)
if do_rescale and do_normalize: if image_type == ImageType.PIL:
# fused rescale and normalize image = F.pil_to_tensor(image)
image = F.normalize(image.to(dtype=torch.float32), new_mean, new_std) elif image_type == ImageType.NUMPY:
elif do_rescale: # not using F.to_tensor as it doesn't handle (C, H, W) numpy arrays
image = image * rescale_factor image = torch.from_numpy(image).contiguous()
elif do_normalize:
image = F.normalize(image, image_mean, image_std)
images.append(image) # We assume that all images have the same channel dimension format.
image_sizes.append(get_image_size(image, input_data_format)) if input_data_format is None:
batch_images.append(images) input_data_format = infer_channel_dimension_format(image)
batch_image_sizes.append(image_sizes)
return BatchMixFeature( if input_data_format == ChannelDimension.LAST:
data={"pixel_values": batch_images, "image_sizes": batch_image_sizes}, image = image.permute(2, 0, 1).contiguous()
tensor_type=None,
image = image.to(device)
if do_resize:
interpolation = (
pil_torch_interpolation_mapping[resample]
if isinstance(resample, (PILImageResampling, int))
else resample
)
image = self.resize(
image=image,
size=size,
patch_size=patch_size,
interpolation=interpolation,
)
if do_rescale and do_normalize:
# fused rescale and normalize
image = F.normalize(image.to(dtype=torch.float32), new_mean, new_std)
elif do_rescale:
image = image * rescale_factor
elif do_normalize:
image = F.normalize(image, image_mean, image_std)
batch_images.append(image)
batch_image_sizes.append(get_image_size(image, ChannelDimension.FIRST))
pixel_values = self._pad_for_batching(
pixel_values=batch_images,
image_sizes=batch_image_sizes,
)
return BatchFeature(
data={"pixel_values": pixel_values, "image_sizes": batch_image_sizes}, tensor_type=return_tensors
) )

View File

@@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
"""PyTorch Pixtral model.""" """PyTorch Pixtral model."""
from typing import List, Optional, Tuple, Union from typing import Optional, Tuple, Union
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
@@ -57,7 +57,7 @@ class PixtralRotaryEmbedding(nn.Module):
a corresponding positional embedding, based on its index in the grid. a corresponding positional embedding, based on its index in the grid.
""" """
def __init__(self, config, device): def __init__(self, config, device=None):
super().__init__() super().__init__()
self.rope_type = "default" self.rope_type = "default"
self.dim = config.head_dim self.dim = config.head_dim
@@ -89,7 +89,6 @@ class PixtralRotaryEmbedding(nn.Module):
# Core RoPE block # Core RoPE block
freqs = self.inv_freq[position_ids] freqs = self.inv_freq[position_ids]
# position_ids_expanded = position_ids[:, None, :].float()
# Force float32 (see https://github.com/huggingface/transformers/pull/29285) # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
device_type = x.device.type device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
@@ -175,7 +174,7 @@ class PixtralAttention(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
position_embeddings: Optional[torch.Tensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Input shape: Batch x Time x Channel""" """Input shape: Batch x Time x Channel"""
@@ -261,8 +260,8 @@ class PixtralAttentionLayer(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
attention_mask: torch.Tensor, attention_mask: torch.Tensor,
position_embeddings: Optional[torch.Tensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = None,
) -> Tuple[torch.FloatTensor]: ) -> Tuple[torch.FloatTensor]:
""" """
Args: Args:
@@ -310,7 +309,7 @@ class PixtralTransformer(nn.Module):
self, self,
inputs_embeds, inputs_embeds,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
position_embeddings: Optional[torch.Tensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
@@ -375,7 +374,7 @@ class PixtralTransformer(nn.Module):
if not return_dict: if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
return BaseModelOutput( return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=[hidden_states], attentions=all_attentions last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
) )
@@ -399,10 +398,9 @@ PIXTRAL_START_DOCSTRING = r"""
class PixtralPreTrainedModel(PreTrainedModel): class PixtralPreTrainedModel(PreTrainedModel):
config_class = PixtralVisionConfig config_class = PixtralVisionConfig
base_model_prefix = "model" base_model_prefix = "model"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
_no_split_modules = ["PixtralVisionAttention"] _no_split_modules = ["PixtralAttentionLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_cache_class = True
def _init_weights(self, module): def _init_weights(self, module):
std = ( std = (
@@ -426,6 +424,8 @@ PIXTRAL_INPUTS_DOCSTRING = r"""
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`AutoImageProcessor.__call__`] Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`AutoImageProcessor.__call__`]
for details. for details.
image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`, *optional*):
The sizes of the images in the batch, being (height, width) for each image.
output_attentions (`bool`, *optional*): output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail. tensors for more detail.
@@ -470,15 +470,22 @@ class PixtralVisionModel(PixtralPreTrainedModel):
stride=config.patch_size, stride=config.patch_size,
bias=False, bias=False,
) )
self.patch_size = config.patch_size
self.ln_pre = PixtralRMSNorm(config.hidden_size, eps=1e-5) self.ln_pre = PixtralRMSNorm(config.hidden_size, eps=1e-5)
self.transformer = PixtralTransformer(config) self.transformer = PixtralTransformer(config)
self.patch_positional_embedding = PixtralRotaryEmbedding(config, device=self.device) self.patch_positional_embedding = PixtralRotaryEmbedding(config)
self.post_init()
def get_input_embeddings(self):
return self.patch_conv
@add_start_docstrings_to_model_forward(PIXTRAL_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(PIXTRAL_INPUTS_DOCSTRING)
def forward( def forward(
self, self,
pixel_values: List[torch.Tensor], pixel_values: torch.Tensor,
output_hidden_states: Optional[bool] = False, image_sizes: torch.Tensor,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
*args, *args,
@@ -490,24 +497,36 @@ class PixtralVisionModel(PixtralPreTrainedModel):
all tokens of all images of shape (N_toks, D) all tokens of all images of shape (N_toks, D)
""" """
# pass images through initial convolution independently # pass images through initial convolution independently
if len(pixel_values) > 1: patch_embeds = self.patch_conv(pixel_values)
raise ValueError("Batching/padding not supported yet!") patch_embeds_list = [
patch_embeds_list = [self.patch_conv(img.to(self.dtype)) for sample in pixel_values for img in sample] embed[..., : (size[0] // self.patch_size), : (size[1] // self.patch_size)]
for embed, size in zip(patch_embeds, image_sizes)
]
# flatten to a single sequence # flatten to a single sequence
patch_embeds = torch.cat([p.flatten(1).T for p in patch_embeds_list], dim=0).unsqueeze(0) patch_embeds = torch.cat([p.flatten(1).T for p in patch_embeds_list], dim=0).unsqueeze(0)
patch_embeds = self.ln_pre(patch_embeds) patch_embeds = self.ln_pre(patch_embeds)
# positional embeddings # positional embeddings
position_ids = position_ids_in_meshgrid( position_ids = position_ids_in_meshgrid(
patch_embeds_list, max_width=self.config.image_size // self.config.patch_size patch_embeds_list, max_width=self.config.image_size // self.config.patch_size
).to(self.device) )
position_embeddings = self.patch_positional_embedding(patch_embeds, position_ids)
position_embedding = self.patch_positional_embedding(patch_embeds, position_ids)
attention_mask = generate_block_attention_mask( attention_mask = generate_block_attention_mask(
[p.shape[-2] * p.shape[-1] for p in patch_embeds_list], patch_embeds [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], patch_embeds
) )
return self.transformer(patch_embeds, attention_mask, position_embedding)
out = self.transformer(
patch_embeds,
attention_mask=attention_mask,
position_embeddings=position_embeddings,
output_hidden_states=output_hidden_states,
output_attentions=output_attentions,
return_dict=return_dict,
)
return out
__all__ = ["PixtralVisionModel", "PixtralPreTrainedModel"] __all__ = ["PixtralVisionModel", "PixtralPreTrainedModel"]

View File

@@ -22,7 +22,7 @@ from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, is_valid_image, load_image from ...image_utils import ImageInput, is_valid_image, load_image
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import is_torch_device, is_torch_dtype, logging, requires_backends from ...utils import logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@@ -50,58 +50,6 @@ def is_image_or_image_url(elem):
return is_url(elem) or is_valid_image(elem) return is_url(elem) or is_valid_image(elem)
# Copied from transformers.models.pixtral.image_processing_pixtral.BatchMixFeature
class BatchMixFeature(BatchFeature):
    """`BatchFeature` variant whose `to` method also walks through (nested) lists of tensors."""

    def to(self, *args, **kwargs) -> "BatchMixFeature":
        """
        Move and/or cast every tensor stored in `self.data` by calling `tensor.to(...)` (PyTorch only).

        Floating point tensors receive the full `to(*args, **kwargs)` call, so they can be cast and moved.
        Other tensors are only moved to the requested device (if any) and never cast, which avoids turning
        integer tensors such as token ids into floats. Values that are neither tensors nor lists are left as is.

        Args:
            args (`Tuple`):
                Forwarded to the `to(...)` function of the tensors. When no `device` keyword is given, the
                first positional argument is inspected: a dtype is accepted as is, while a string, a torch
                device or an int is taken as the target device.
            kwargs (`Dict`, *optional*):
                Forwarded to the `to(...)` function of the tensors. A `device` entry, if present, is the
                target device.

        Returns:
            [`BatchMixFeature`]: The same instance after modification.

        Raises:
            ValueError: If the first positional argument is neither a dtype nor something usable as a device.
        """
        requires_backends(self, ["torch"])
        import torch  # noqa

        target_device = kwargs.get("device")
        # Without an explicit `device=` keyword, the first positional argument decides: dtype or device.
        if target_device is None and len(args) > 0:
            arg = args[0]
            if not is_torch_dtype(arg):
                if isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int):
                    target_device = arg
                else:
                    raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")

        def _move(value):
            # Lists may be nested, so recurse until tensors (or other leaves) are reached.
            if isinstance(value, list):
                return [_move(item) for item in value]
            if not isinstance(value, torch.Tensor):
                return value
            # Only floating point tensors get the full call (cast + move).
            if torch.is_floating_point(value):
                return value.to(*args, **kwargs)
            # Non-float tensors are moved but never cast.
            if target_device is not None:
                return value.to(device=target_device)
            return value

        self.data = {key: _move(value) for key, value in self.data.items()}
        return self
class PixtralProcessor(ProcessorMixin): class PixtralProcessor(ProcessorMixin):
r""" r"""
Constructs a Pixtral processor which wraps a Pixtral image processor and a Pixtral tokenizer into a single processor. Constructs a Pixtral processor which wraps a Pixtral image processor and a Pixtral tokenizer into a single processor.
@@ -161,7 +109,7 @@ class PixtralProcessor(ProcessorMixin):
audio=None, audio=None,
videos=None, videos=None,
**kwargs: Unpack[PixtralProcessorKwargs], **kwargs: Unpack[PixtralProcessorKwargs],
) -> BatchMixFeature: ) -> BatchFeature:
""" """
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
@@ -205,28 +153,16 @@ class PixtralProcessor(ProcessorMixin):
if images is not None: if images is not None:
if is_image_or_image_url(images): if is_image_or_image_url(images):
if isinstance(text, str) or isinstance(text, list) and len(text) == 1: images = [images]
# If there's a single sample, the image must belong to it
images = [[images]]
else:
raise ValueError(
"You have supplied multiple text samples, but `images` is not a nested list. When processing multiple samples, `images` should be a list of lists of images, one list per sample."
)
elif isinstance(images, list) and is_image_or_image_url(images[0]): elif isinstance(images, list) and is_image_or_image_url(images[0]):
if isinstance(text, str) or isinstance(text, list) and len(text) == 1:
# If there's a single sample, all images must belong to it
images = [images]
else:
raise ValueError(
"You have supplied multiple text samples, but `images` is not a nested list. When processing multiple samples, `images` should be a list of lists of images, one list per sample."
)
elif isinstance(images, list) and isinstance(images[0], list) and is_image_or_image_url(images[0][0]):
pass pass
elif isinstance(images, list) and isinstance(images[0], list) and is_image_or_image_url(images[0][0]):
images = [image for sublist in images for image in sublist]
else: else:
raise ValueError( raise ValueError(
"Invalid input images. Please provide a single image, a list of images, or a list of lists of images." "Invalid input images. Please provide a single image, a list of images, or a list of lists of images."
) )
images = [[load_image(im) for im in sample] for sample in images] images = [load_image(im) if isinstance(im, str) else im for im in images]
image_inputs = self.image_processor(images, patch_size=self.patch_size, **output_kwargs["images_kwargs"]) image_inputs = self.image_processor(images, patch_size=self.patch_size, **output_kwargs["images_kwargs"])
else: else:
image_inputs = {} image_inputs = {}
@@ -240,15 +176,13 @@ class PixtralProcessor(ProcessorMixin):
prompt_strings = text prompt_strings = text
if image_inputs.get("pixel_values") is not None: if image_inputs.get("pixel_values") is not None:
# Replace the image token with the expanded image token sequence # Replace the image token with the expanded image token sequence
images = image_inputs["pixel_values"] image_sizes = iter(image_inputs["image_sizes"])
image_sizes = image_inputs.pop("image_sizes")
prompt_strings = [] prompt_strings = []
replace_strings = []
for sample_images, sample_image_sizes, sample in zip(images, image_sizes, text): for sample in text:
replace_strings = [] while self.image_token in sample:
# First calculate the number of tokens needed for each image and put in a placeholder height, width = next(image_sizes)
for image, image_size in zip(sample_images, sample_image_sizes):
height, width = image_size
num_height_tokens = height // self.patch_size num_height_tokens = height // self.patch_size
num_width_tokens = width // self.patch_size num_width_tokens = width // self.patch_size
replace_tokens = [ replace_tokens = [
@@ -267,7 +201,9 @@ class PixtralProcessor(ProcessorMixin):
prompt_strings.append(sample) prompt_strings.append(sample)
text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"]) text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
return BatchMixFeature(data={**text_inputs, **image_inputs}) return BatchFeature(
data={**text_inputs, **image_inputs}, tensor_type=output_kwargs["common_kwargs"]["return_tensors"]
)
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
def batch_decode(self, *args, **kwargs): def batch_decode(self, *args, **kwargs):

View File

@@ -564,9 +564,8 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
self.assertTrue(processor.batch_decode(output, skip_special_tokens=True)[0] == EXPECTED_DECODED_TEXT) self.assertTrue(processor.batch_decode(output, skip_special_tokens=True)[0] == EXPECTED_DECODED_TEXT)
@slow @slow
@require_bitsandbytes
def test_pixtral(self): def test_pixtral(self):
model_id = "hf-internal-testing/pixtral-12b" model_id = "mistral-community/pixtral-12b"
model = LlavaForConditionalGeneration.from_pretrained(model_id) model = LlavaForConditionalGeneration.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id) processor = AutoProcessor.from_pretrained(model_id)
@@ -579,33 +578,75 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
PROMPT = "<s>[INST]Describe the images.\n[IMG][IMG][IMG][IMG][/INST]" PROMPT = "<s>[INST]Describe the images.\n[IMG][IMG][IMG][IMG][/INST]"
# image = Image.open(requests.get(url, stream=True).raw) # image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(text=PROMPT, images=IMG_URLS, return_tensors="pt").to("cuda") inputs = processor(text=PROMPT, images=IMG_URLS, return_tensors="pt").to(model.device)
generate_ids = model.generate(**inputs, max_new_tokens=500) generate_ids = model.generate(**inputs, max_new_tokens=500)
ouptut = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] ouptut = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
print(ouptut)
# fmt: off # fmt: off
EXPECTED_GENERATION = """ EXPECTED_GENERATION = """
Describe the images. Describe the images.
Sure, let's break down each image description: Certainly! Here are the descriptions of the images:
1. **Image 1:** 1. **Image 1**: This image features a black dog with a glossy coat sitting on a wooden surface. The dog has a calm and attentive expression, looking directly at the camera. The wooden background has a rustic appearance with visible grain and texture.
- **Description:** A black dog with a glossy coat is sitting on a wooden floor. The dog has a focused expression and is looking directly at the camera.
- **Details:** The wooden floor has a rustic appearance with visible wood grain patterns. The dog's eyes are a striking color, possibly brown or amber, which contrasts with its black fur.
2. **Image 2:** 2. **Image 2**: This image captures a breathtaking view of a mountainous landscape. The mountains are rugged and covered with patches of green vegetation. The sky above is clear, and the scene conveys a sense of tranquility and natural beauty.
- **Description:** A scenic view of a mountainous landscape with a winding road cutting through it. The road is surrounded by lush green vegetation and leads to a distant valley.
- **Details:** The mountains are rugged with steep slopes, and the sky is clear, indicating good weather. The winding road adds a sense of depth and perspective to the image.
3. **Image 3:** 3. **Image 3**: This image shows a beach scene during sunset. The waves are gently rolling onto the shore, and several people can be seen in the water, possibly surfing or swimming. The sky is painted with warm hues of orange and yellow, creating a serene and picturesque atmosphere.
- **Description:** A beach scene with waves crashing against the shore. There are several people in the water and on the beach, enjoying the waves and the sunset.
- **Details:** The waves are powerful, creating a dynamic and lively atmosphere. The sky is painted with hues of orange and pink from the setting sun, adding a warm glow to the scene.
4. **Image 4:** 4. **Image 4**: This image depicts a narrow, winding path that cuts through a lush, green landscape. On either side of the path, there is dense grass and various trees, including a prominent tree with white blossoms. The sky is clear and blue, adding to the peaceful and inviting ambiance of the scene.
- **Description:** A garden path leading to a large tree with a bench underneath it. The path is bordered by well-maintained grass and flowers.
- **Details:** The path is made of small stones or gravel, and the tree provides a shaded area with the bench invitingly placed beneath it. The surrounding area is lush and green, suggesting a well-kept garden.
Each image captures a different scene, from a close-up of a dog to expansive natural landscapes, showcasing various elements of nature and human interaction with it. These descriptions provide a detailed overview of the content and atmosphere of each image.
""" """
# fmt: on # fmt: on
# check that both inputs are handled correctly and generate the same output # check that both inputs are handled correctly and generate the same output
self.assertListEqual(ouptut, EXPECTED_GENERATION) self.assertEqual(ouptut, EXPECTED_GENERATION)
@slow
@require_bitsandbytes
def test_pixtral_4bit(self):
    # Single prompt with two images, model loaded in 4-bit; the decoded text must match the reference.
    checkpoint = "mistral-community/pixtral-12b"
    model = LlavaForConditionalGeneration.from_pretrained(checkpoint, load_in_4bit=True)
    processor = AutoProcessor.from_pretrained(checkpoint)

    # Images are fetched in the same order as the [IMG] placeholders appear in the prompt.
    image_urls = (
        "https://picsum.photos/id/237/400/300",
        "https://picsum.photos/id/231/200/300",
    )
    images = [Image.open(requests.get(url, stream=True).raw) for url in image_urls]
    prompt = "<s>[INST][IMG][IMG]Describe the images.[/INST]"

    encoded = processor(text=prompt, images=images, return_tensors="pt")
    inputs = encoded.to(torch_device, torch.float16)

    generated_ids = model.generate(**inputs, max_new_tokens=50)
    decoded = processor.batch_decode(
        generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    output = decoded[0]

    expected_generation = "Describe the images.The image showcases a dog, which is prominently positioned in the center, taking up a significant portion of the frame. The dog is situated against a backdrop of a wooden surface, which spans the entire image. The dog appears to be a black Labrador"  # fmt: skip
    self.assertEqual(output, expected_generation)
@slow
@require_bitsandbytes
def test_pixtral_batched(self):
    # Two prompts with one image each, padded into a single batch; both decoded texts must match.
    checkpoint = "mistral-community/pixtral-12b"
    model = LlavaForConditionalGeneration.from_pretrained(checkpoint, load_in_4bit=True)
    processor = AutoProcessor.from_pretrained(checkpoint)
    # Padding is requested below, so a pad token is needed; the EOS token id is reused for it.
    processor.tokenizer.pad_token_id = processor.tokenizer.eos_token_id

    image_urls = (
        "https://picsum.photos/id/237/400/300",
        "https://picsum.photos/id/17/150/500",
    )
    images = [Image.open(requests.get(url, stream=True).raw) for url in image_urls]
    prompts = [
        "<s>[INST][IMG]What breed is the dog?[/INST]",
        "<s>[INST][IMG]What is shown in this image?[/INST]",
    ]

    encoded = processor(text=prompts, images=images, padding=True, return_tensors="pt")
    inputs = encoded.to(torch_device, torch.float16)

    generated_ids = model.generate(**inputs, max_new_tokens=50)
    output = processor.batch_decode(
        generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )

    expected_generation = [
        'What breed is the dog?The dog in the image is a black Labrador Retriever.',
        'What is shown in this image?The image depicts a narrow, winding dirt path surrounded by lush greenery. The path is flanked by grass and shrubs on both sides. On the left side, there are tall trees and dense foliage, while on the right side, there'
    ]  # fmt: skip
    self.assertEqual(output, expected_generation)

View File

@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import random
import time import time
import unittest import unittest
@@ -92,49 +91,47 @@ class PixtralImageProcessingTester:
"do_convert_rgb": self.do_convert_rgb, "do_convert_rgb": self.do_convert_rgb,
} }
def expected_output_image_shape(self, image): def expected_output_image_shape(self, images):
if isinstance(image, Image.Image): if not isinstance(images, (list, tuple)):
width, height = image.size images = [images]
elif isinstance(image, np.ndarray):
height, width = image.shape[:2]
elif isinstance(image, torch.Tensor):
height, width = image.shape[-2:]
max_height = max_width = self.size.get("longest_edge") batch_size = len(images)
return_height, return_width = 0, 0
for image in images:
if isinstance(image, Image.Image):
width, height = image.size
elif isinstance(image, np.ndarray):
height, width = image.shape[:2]
elif isinstance(image, torch.Tensor):
height, width = image.shape[-2:]
ratio = max(height / max_height, width / max_width) max_height = max_width = self.size.get("longest_edge")
if ratio > 1:
height = int(np.ceil(height / ratio))
width = int(np.ceil(width / ratio))
patch_height, patch_width = self.patch_size["height"], self.patch_size["width"] ratio = max(height / max_height, width / max_width)
num_height_tokens = (height - 1) // patch_height + 1 if ratio > 1:
num_width_tokens = (width - 1) // patch_width + 1 height = int(np.ceil(height / ratio))
width = int(np.ceil(width / ratio))
height = num_height_tokens * patch_height patch_height, patch_width = self.patch_size["height"], self.patch_size["width"]
width = num_width_tokens * patch_width num_height_tokens = (height - 1) // patch_height + 1
num_width_tokens = (width - 1) // patch_width + 1
return self.num_channels, height, width return_height = max(num_height_tokens * patch_height, return_height)
return_width = max(num_width_tokens * patch_width, return_width)
return batch_size, self.num_channels, return_height, return_width
def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False): def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
# Use prepare_image_inputs to make a list of list of single images images = prepare_image_inputs(
batch_size=self.batch_size,
images_list = [] num_channels=self.num_channels,
for _ in range(self.batch_size): min_resolution=self.min_resolution,
images = [] max_resolution=self.max_resolution,
for _ in range(random.randint(1, self.max_num_images_per_sample)): equal_resolution=equal_resolution,
img = prepare_image_inputs( numpify=numpify,
batch_size=1, torchify=torchify,
num_channels=self.num_channels, )
min_resolution=self.min_resolution, return images
max_resolution=self.max_resolution,
equal_resolution=equal_resolution,
numpify=numpify,
torchify=torchify,
)[0]
images.append(img)
images_list.append(images)
return images_list
@require_torch @require_torch
@@ -173,23 +170,18 @@ class PixtralImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing = image_processing_class(**self.image_processor_dict) image_processing = image_processing_class(**self.image_processor_dict)
# create random PIL images # create random PIL images
image_inputs_list = self.image_processor_tester.prepare_image_inputs() image_inputs_list = self.image_processor_tester.prepare_image_inputs()
for image_inputs in image_inputs_list: for image in image_inputs_list:
for image in image_inputs: self.assertIsInstance(image, Image.Image)
self.assertIsInstance(image, Image.Image)
# Test not batched input # Test not batched input
encoded_images = image_processing(image_inputs_list[0][0], return_tensors="pt").pixel_values encoded_images = image_processing(image_inputs_list[0], return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape( expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs_list[0])
image_inputs_list[0][0] self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
)
self.assertEqual(tuple(encoded_images[0][0].shape), expected_output_image_shape)
# Test batched # Test batched
batch_encoded_images = image_processing(image_inputs_list, return_tensors="pt").pixel_values encoded_images = image_processing(image_inputs_list, return_tensors="pt").pixel_values
for encoded_images, images in zip(batch_encoded_images, image_inputs_list): expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs_list)
for encoded_image, image in zip(encoded_images, images): self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image)
self.assertEqual(tuple(encoded_image.shape), expected_output_image_shape)
def test_call_numpy(self): def test_call_numpy(self):
for image_processing_class in self.image_processor_list: for image_processing_class in self.image_processor_list:
@@ -197,23 +189,18 @@ class PixtralImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing = image_processing_class(**self.image_processor_dict) image_processing = image_processing_class(**self.image_processor_dict)
# create random numpy tensors # create random numpy tensors
image_inputs_list = self.image_processor_tester.prepare_image_inputs(numpify=True) image_inputs_list = self.image_processor_tester.prepare_image_inputs(numpify=True)
for image_inputs in image_inputs_list: for image in image_inputs_list:
for image in image_inputs: self.assertIsInstance(image, np.ndarray)
self.assertIsInstance(image, np.ndarray)
# Test not batched input # Test not batched input
encoded_images = image_processing(image_inputs_list[0][0], return_tensors="pt").pixel_values encoded_images = image_processing(image_inputs_list[0], return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape( expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs_list[0])
image_inputs_list[0][0] self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
)
self.assertEqual(tuple(encoded_images[0][0].shape), expected_output_image_shape)
# Test batched # Test batched
batch_encoded_images = image_processing(image_inputs_list, return_tensors="pt").pixel_values batch_encoded_images = image_processing(image_inputs_list, return_tensors="pt").pixel_values
for encoded_images, images in zip(batch_encoded_images, image_inputs_list): expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs_list)
for encoded_image, image in zip(encoded_images, images): self.assertEqual(tuple(batch_encoded_images.shape), expected_output_image_shape)
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image)
self.assertEqual(tuple(encoded_image.shape), expected_output_image_shape)
def test_call_pytorch(self): def test_call_pytorch(self):
for image_processing_class in self.image_processor_list: for image_processing_class in self.image_processor_list:
@@ -221,23 +208,18 @@ class PixtralImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing = image_processing_class(**self.image_processor_dict) image_processing = image_processing_class(**self.image_processor_dict)
# create random PyTorch tensors # create random PyTorch tensors
image_inputs_list = self.image_processor_tester.prepare_image_inputs(torchify=True) image_inputs_list = self.image_processor_tester.prepare_image_inputs(torchify=True)
for image_inputs in image_inputs_list: for image in image_inputs_list:
for image in image_inputs: self.assertIsInstance(image, torch.Tensor)
self.assertIsInstance(image, torch.Tensor)
# Test not batched input # Test not batched input
encoded_images = image_processing(image_inputs_list[0][0], return_tensors="pt").pixel_values encoded_images = image_processing(image_inputs_list[0], return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape( expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs_list[0])
image_inputs_list[0][0] self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
)
self.assertEqual(tuple(encoded_images[0][0].shape), expected_output_image_shape)
# Test batched # Test batched
batch_encoded_images = image_processing(image_inputs_list, return_tensors="pt").pixel_values batch_encoded_images = image_processing(image_inputs_list, return_tensors="pt").pixel_values
for encoded_images, images in zip(batch_encoded_images, image_inputs_list): expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs_list)
for encoded_image, image in zip(encoded_images, images): self.assertEqual(tuple(batch_encoded_images.shape), expected_output_image_shape)
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image)
self.assertEqual(tuple(encoded_image.shape), expected_output_image_shape)
@require_vision @require_vision
@require_torch @require_torch

View File

@@ -74,15 +74,17 @@ class PixtralVisionModelTester:
self.initializer_range = initializer_range self.initializer_range = initializer_range
self.scope = scope self.scope = scope
# in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token) # in Pixtral, the seq length equals the number of patches * batch_size because the patches are flattened
num_patches = (image_size // patch_size) ** 2 self.seq_length = (image_size // patch_size) ** 2 * batch_size
self.seq_length = num_patches + 1
def prepare_config_and_inputs(self): def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
image_sizes = torch.tensor(
[[self.image_size, self.image_size]] * self.batch_size, dtype=torch.long, device=torch_device
)
config = self.get_config() config = self.get_config()
return config, pixel_values return config, pixel_values, image_sizes
def get_config(self): def get_config(self):
return PixtralVisionConfig( return PixtralVisionConfig(
@@ -127,8 +129,8 @@ class PixtralVisionModelTester:
def prepare_config_and_inputs_for_common(self): def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs() config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values = config_and_inputs config, pixel_values, image_sizes = config_and_inputs
inputs_dict = {"pixel_values": pixel_values} inputs_dict = {"pixel_values": pixel_values, "image_sizes": image_sizes}
return config, inputs_dict return config, inputs_dict
@@ -142,113 +144,17 @@ class PixtralVisionModelModelTest(ModelTesterMixin, unittest.TestCase):
test_pruning = False test_pruning = False
test_head_masking = False test_head_masking = False
test_torchscript = False test_torchscript = False
test_resize_embeddings = False
def setUp(self): def setUp(self):
self.model_tester = PixtralVisionModelTester(self) self.model_tester = PixtralVisionModelTester(self)
self.config_tester = ConfigTester(self, config_class=PixtralVisionConfig, has_text_modality=False) self.config_tester = ConfigTester(self, config_class=PixtralVisionConfig, has_text_modality=False)
@unittest.skip("model does not support input embeds")
def test_inputs_embeds(self):
pass
@unittest.skip("model does not support input embeds")
def test_inputs_embeds_matches_input_ids(self):
pass
@unittest.skip(
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
def test_training_gradient_checkpointing(self):
pass
@unittest.skip(
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
def test_training_gradient_checkpointing_use_reentrant(self):
pass
@unittest.skip(
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
@unittest.skip(reason="Compile not yet supported because in Pixtral models")
def test_sdpa_can_compile_dynamic(self):
pass
@unittest.skip(reason="Compile not yet supported because in Pixtral models")
def test_sdpa_can_dispatch_on_flash(self):
pass
@unittest.skip(reason="Not supported yet")
def test_attention_outputs(self):
pass
@unittest.skip(reason="Not supported yet")
def test_cpu_offload(self):
pass
@unittest.skip(reason="Not supported yet")
def test_batching_equivalence(self):
pass
@unittest.skip(reason="Not supported yet")
def test_disk_offload_bin(self):
pass
@unittest.skip(reason="Not supported yet")
def test_retain_grad_hidden_states_attentions(self):
pass
@unittest.skip(reason="Not supported yet")
def test_multi_gpu_data_parallel_forward(self):
pass
@unittest.skip(reason="Not supported yet")
def test_model_parallelism(self):
pass
@unittest.skip(reason="Not supported yet")
def test_model_outputs_equivalence(self):
pass
@unittest.skip(reason="Not supported yet")
def test_save_load(self):
pass
@unittest.skip(reason="Not supported yet")
def test_model_get_set_embeddings(self): def test_model_get_set_embeddings(self):
pass config, _ = self.model_tester.prepare_config_and_inputs_for_common()
@unittest.skip(reason="Not supported yet") for model_class in self.all_model_classes:
def test_resize_tokens_embeddings(self): model = model_class(config)
pass self.assertIsInstance(model.get_input_embeddings(), (torch.nn.Module))
x = model.get_output_embeddings()
@unittest.skip(reason="Not supported yet") self.assertTrue(x is None or isinstance(x, torch.nn.Linear))
def test_model_main_input_name(self):
pass
@unittest.skip(reason="Not supported yet")
def test_initialization(self):
pass
@unittest.skip(reason="Not supported yet")
def test_hidden_states_output(self):
pass
@unittest.skip(reason="Not supported yet")
def test_gradient_checkpointing_backward_compatibility(self):
pass
@unittest.skip(reason="Not supported yet")
def test_feed_forward_chunking(self):
pass
@unittest.skip(reason="Not supported yet")
def test_disk_offload_safetensors(self):
pass
@unittest.skip(reason="Not supported yet")
def test_determinism(self):
pass

View File

@@ -14,7 +14,6 @@
import shutil import shutil
import tempfile import tempfile
import unittest import unittest
from typing import Optional
import requests import requests
import torch import torch
@@ -28,7 +27,7 @@ from ...test_processing_common import ProcessorTesterMixin
if is_vision_available(): if is_vision_available():
from PIL import Image from PIL import Image
from transformers import AutoTokenizer, PixtralImageProcessor, PixtralProcessor from transformers import PixtralProcessor
@require_vision @require_vision
@@ -46,20 +45,15 @@ class PixtralProcessorTest(ProcessorTesterMixin, unittest.TestCase):
def setUp(self): def setUp(self):
self.tmpdirname = tempfile.mkdtemp() self.tmpdirname = tempfile.mkdtemp()
processor = PixtralProcessor.from_pretrained("mistral-community/pixtral-12b")
# FIXME - just load the processor directly from the checkpoint
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/pixtral-12b")
image_processor = PixtralImageProcessor()
processor = PixtralProcessor(tokenizer=tokenizer, image_processor=image_processor)
processor.save_pretrained(self.tmpdirname) processor.save_pretrained(self.tmpdirname)
def tearDown(self): def tearDown(self):
shutil.rmtree(self.tmpdirname) shutil.rmtree(self.tmpdirname)
@unittest.skip("No chat template was set for this model (yet)")
def test_chat_template(self): def test_chat_template(self):
processor = self.processor_class.from_pretrained(self.tmpdirname) processor = self.processor_class.from_pretrained(self.tmpdirname)
expected_prompt = "USER: [IMG]\nWhat is shown in this image? ASSISTANT:" expected_prompt = "<s>[INST][IMG]What is shown in this image?[/INST]"
messages = [ messages = [
{ {
@@ -73,13 +67,12 @@ class PixtralProcessorTest(ProcessorTesterMixin, unittest.TestCase):
formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True) formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
self.assertEqual(expected_prompt, formatted_prompt) self.assertEqual(expected_prompt, formatted_prompt)
@unittest.skip("No chat template was set for this model (yet)")
def test_image_token_filling(self): def test_image_token_filling(self):
processor = self.processor_class.from_pretrained(self.tmpdirname) processor = self.processor_class.from_pretrained(self.tmpdirname)
# Important to check with non square image # Important to check with non square image
image = torch.randint(0, 2, (3, 500, 316)) image = torch.randint(0, 2, (3, 500, 316))
expected_image_tokens = 1526 expected_image_tokens = 640
image_token_index = 32000 image_token_index = 10
messages = [ messages = [
{ {
@@ -111,11 +104,8 @@ class PixtralProcessorTest(ProcessorTesterMixin, unittest.TestCase):
self.assertIn("input_ids", inputs_image) self.assertIn("input_ids", inputs_image)
self.assertTrue(len(inputs_image["input_ids"]) == 1) self.assertTrue(len(inputs_image["input_ids"]) == 1)
self.assertIsInstance(inputs_image["input_ids"], torch.Tensor) self.assertIsInstance(inputs_image["input_ids"], torch.Tensor)
self.assertIsInstance(inputs_image["pixel_values"], list) self.assertIsInstance(inputs_image["pixel_values"], torch.Tensor)
self.assertTrue(len(inputs_image["pixel_values"]) == 1) self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([1, 3, 32, 32]))
self.assertIsInstance(inputs_image["pixel_values"][0], list)
self.assertTrue(len(inputs_image["pixel_values"][0]) == 1)
self.assertIsInstance(inputs_image["pixel_values"][0][0], torch.Tensor)
# fmt: off # fmt: off
input_ids = inputs_image["input_ids"] input_ids = inputs_image["input_ids"]
@@ -131,11 +121,8 @@ class PixtralProcessorTest(ProcessorTesterMixin, unittest.TestCase):
self.assertIn("input_ids", inputs_url) self.assertIn("input_ids", inputs_url)
self.assertTrue(len(inputs_url["input_ids"]) == 1) self.assertTrue(len(inputs_url["input_ids"]) == 1)
self.assertIsInstance(inputs_url["input_ids"], torch.Tensor) self.assertIsInstance(inputs_url["input_ids"], torch.Tensor)
self.assertIsInstance(inputs_url["pixel_values"], list) self.assertIsInstance(inputs_image["pixel_values"], torch.Tensor)
self.assertTrue(len(inputs_url["pixel_values"]) == 1) self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([1, 3, 32, 32]))
self.assertIsInstance(inputs_url["pixel_values"][0], list)
self.assertTrue(len(inputs_url["pixel_values"][0]) == 1)
self.assertIsInstance(inputs_url["pixel_values"][0][0], torch.Tensor)
# fmt: off # fmt: off
input_ids = inputs_url["input_ids"] input_ids = inputs_url["input_ids"]
@@ -146,6 +133,28 @@ class PixtralProcessorTest(ProcessorTesterMixin, unittest.TestCase):
) )
# fmt: on # fmt: on
# Test passing inputs as a single list
inputs_image = processor(text=prompt_string, images=[self.image_0], return_tensors="pt")
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([1, 3, 32, 32]))
# fmt: off
self.assertEqual(
inputs_image["input_ids"][0].tolist(),
[21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 4701, 1307, 1278, 3937, 1063, 1349, 4290, 16002, 41150, 1058]
)
# fmt: on
# Test as nested single list
inputs_image = processor(text=prompt_string, images=[[self.image_0]], return_tensors="pt")
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([1, 3, 32, 32]))
# fmt: off
self.assertEqual(
inputs_image["input_ids"][0].tolist(),
[21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 4701, 1307, 1278, 3937, 1063, 1349, 4290, 16002, 41150, 1058]
)
# fmt: on
def test_processor_with_multiple_images_single_list(self): def test_processor_with_multiple_images_single_list(self):
processor = self.processor_class.from_pretrained(self.tmpdirname) processor = self.processor_class.from_pretrained(self.tmpdirname)
prompt_string = "USER: [IMG][IMG]\nWhat's the difference between these two images? ASSISTANT:" prompt_string = "USER: [IMG][IMG]\nWhat's the difference between these two images? ASSISTANT:"
@@ -159,11 +168,8 @@ class PixtralProcessorTest(ProcessorTesterMixin, unittest.TestCase):
self.assertIn("input_ids", inputs_image) self.assertIn("input_ids", inputs_image)
self.assertTrue(len(inputs_image["input_ids"]) == 1) self.assertTrue(len(inputs_image["input_ids"]) == 1)
self.assertIsInstance(inputs_image["input_ids"], torch.Tensor) self.assertIsInstance(inputs_image["input_ids"], torch.Tensor)
self.assertIsInstance(inputs_image["pixel_values"], list) self.assertIsInstance(inputs_image["pixel_values"], torch.Tensor)
self.assertTrue(len(inputs_image["pixel_values"]) == 1) self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([2, 3, 32, 32]))
self.assertIsInstance(inputs_image["pixel_values"][0], list)
self.assertTrue(len(inputs_image["pixel_values"][0]) == 2)
self.assertIsInstance(inputs_image["pixel_values"][0][0], torch.Tensor)
# fmt: off # fmt: off
input_ids = inputs_image["input_ids"] input_ids = inputs_image["input_ids"]
@@ -179,11 +185,9 @@ class PixtralProcessorTest(ProcessorTesterMixin, unittest.TestCase):
self.assertIn("input_ids", inputs_url) self.assertIn("input_ids", inputs_url)
self.assertTrue(len(inputs_url["input_ids"]) == 1) self.assertTrue(len(inputs_url["input_ids"]) == 1)
self.assertIsInstance(inputs_url["input_ids"], torch.Tensor) self.assertIsInstance(inputs_url["input_ids"], torch.Tensor)
self.assertIsInstance(inputs_url["pixel_values"], list) self.assertIsInstance(inputs_url["pixel_values"], torch.Tensor)
self.assertTrue(len(inputs_url["pixel_values"]) == 1) self.assertTrue(inputs_url["pixel_values"].shape == torch.Size([2, 3, 32, 32]))
self.assertIsInstance(inputs_url["pixel_values"][0], list)
self.assertTrue(len(inputs_url["pixel_values"][0]) == 2)
self.assertIsInstance(inputs_url["pixel_values"][0][0], torch.Tensor)
# fmt: off # fmt: off
input_ids = inputs_url["input_ids"] input_ids = inputs_url["input_ids"]
self.assertEqual( self.assertEqual(
@@ -193,6 +197,17 @@ class PixtralProcessorTest(ProcessorTesterMixin, unittest.TestCase):
) )
# fmt: on # fmt: on
# Test passing in as a nested list
inputs_url = processor(text=prompt_string, images=[[self.image_0, self.image_1]], return_tensors="pt")
# Check the output of THIS call (`inputs_url`), not the stale `inputs_image` left over
# from an earlier call, so the nested-list path's pixel values are actually verified.
self.assertTrue(inputs_url["pixel_values"].shape == torch.Size([2, 3, 32, 32]))
# fmt: off
self.assertEqual(
    inputs_url["input_ids"][0].tolist(),
    [21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058]
)
# fmt: on
def test_processor_with_multiple_images_multiple_lists(self): def test_processor_with_multiple_images_multiple_lists(self):
processor = self.processor_class.from_pretrained(self.tmpdirname) processor = self.processor_class.from_pretrained(self.tmpdirname)
prompt_string = [ prompt_string = [
@@ -211,11 +226,8 @@ class PixtralProcessorTest(ProcessorTesterMixin, unittest.TestCase):
self.assertIn("input_ids", inputs_image) self.assertIn("input_ids", inputs_image)
self.assertTrue(len(inputs_image["input_ids"]) == 2) self.assertTrue(len(inputs_image["input_ids"]) == 2)
self.assertIsInstance(inputs_image["input_ids"], torch.Tensor) self.assertIsInstance(inputs_image["input_ids"], torch.Tensor)
self.assertIsInstance(inputs_image["pixel_values"], list) self.assertIsInstance(inputs_image["pixel_values"], torch.Tensor)
self.assertTrue(len(inputs_image["pixel_values"]) == 2) self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([3, 3, 32, 32]))
self.assertIsInstance(inputs_image["pixel_values"][0], list)
self.assertTrue(len(inputs_image["pixel_values"][0]) == 2)
self.assertIsInstance(inputs_image["pixel_values"][0][0], torch.Tensor)
# fmt: off # fmt: off
input_ids = inputs_image["input_ids"] input_ids = inputs_image["input_ids"]
@@ -231,11 +243,8 @@ class PixtralProcessorTest(ProcessorTesterMixin, unittest.TestCase):
self.assertIn("input_ids", inputs_url) self.assertIn("input_ids", inputs_url)
self.assertTrue(len(inputs_url["input_ids"]) == 2) self.assertTrue(len(inputs_url["input_ids"]) == 2)
self.assertIsInstance(inputs_url["input_ids"], torch.Tensor) self.assertIsInstance(inputs_url["input_ids"], torch.Tensor)
self.assertIsInstance(inputs_url["pixel_values"], list) self.assertIsInstance(inputs_url["pixel_values"], torch.Tensor)
self.assertTrue(len(inputs_url["pixel_values"]) == 2) self.assertTrue(inputs_url["pixel_values"].shape == torch.Size([3, 3, 32, 32]))
self.assertIsInstance(inputs_url["pixel_values"][0], list)
self.assertTrue(len(inputs_url["pixel_values"][0]) == 2)
self.assertIsInstance(inputs_url["pixel_values"][0][0], torch.Tensor)
# fmt: off # fmt: off
input_ids = inputs_url["input_ids"] input_ids = inputs_url["input_ids"]
@@ -246,6 +255,19 @@ class PixtralProcessorTest(ProcessorTesterMixin, unittest.TestCase):
) )
# fmt: on # fmt: on
# Test passing as a single flat list
inputs_image = processor(
text=prompt_string, images=[self.image_0, self.image_1, self.image_2], return_tensors="pt", padding=True
)
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([3, 3, 32, 32]))
# fmt: off
self.assertEqual(
inputs_image["input_ids"][0].tolist(),
[21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058]
)
# fmt: on
def test_processor_returns_full_length_batches(self): def test_processor_returns_full_length_batches(self):
# to avoid https://github.com/huggingface/transformers/issues/34204 # to avoid https://github.com/huggingface/transformers/issues/34204
processor = self.processor_class.from_pretrained(self.tmpdirname) processor = self.processor_class.from_pretrained(self.tmpdirname)
@@ -264,13 +286,3 @@ class PixtralProcessorTest(ProcessorTesterMixin, unittest.TestCase):
self.assertIn("input_ids", inputs_image) self.assertIn("input_ids", inputs_image)
self.assertTrue(len(inputs_image["input_ids"]) == 5) self.assertTrue(len(inputs_image["input_ids"]) == 5)
self.assertTrue(len(inputs_image["pixel_values"]) == 5) self.assertTrue(len(inputs_image["pixel_values"]) == 5)
# Override as PixtralProcessor needs nested images to work properly with batched inputs
@require_vision
def prepare_image_inputs(self, batch_size: Optional[int] = None):
    """Build image inputs for the processor tests, nested per sample when batched.

    Args:
        batch_size: Number of samples to produce. When `None`, the parent
            implementation's result is returned unchanged.

    Returns:
        The parent's `prepare_image_inputs()` result when `batch_size` is `None`;
        otherwise a list of length `batch_size` whose entries are one-element
        lists wrapping that result.

    Raises:
        ValueError: If `batch_size` is smaller than 1.
    """
    if batch_size is not None:
        if batch_size < 1:
            raise ValueError("batch_size must be greater than 0")
        # The parent is called once; every batch entry refers to the same inner list.
        single_sample = [super().prepare_image_inputs()]
        return [single_sample] * batch_size
    return super().prepare_image_inputs()

View File

@@ -2991,6 +2991,10 @@ class ModelTesterMixin:
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
model_forward_args = inspect.signature(model.forward).parameters
if "inputs_embeds" not in model_forward_args:
self.skipTest(reason="This model doesn't use `inputs_embeds`")
inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class)) inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
if not self.is_encoder_decoder: if not self.is_encoder_decoder: