Add support for Pixtral (#33449)

* initial commit

* gloups

* updates

* work

* weights match

* nits

* nits

* updates to support the tokenizer :)

* updates

* Pixtral processor (#33454)

* rough outline

* Add in image break and end tokens

* Fix

* Udo some formatting changes

* Set patch_size default

* Fix

* Fix token expansion

* nit in conversion script

* Fix image token list creation

* done

* add expected results

* Process list of list of images (#33465)

* updates

* working image and processor

* this is the expected format

* some fixes

* push current updated

* working mult images!

* add a small integration test

* Uodate configuration docstring

* Formatting

* Config docstring fix

* simplify model test

* fixup modeling and etests

* Return BatchMixFeature in image processor

* fix some copies

* update

* nits

* Update model docstring

* Apply suggestions from code review

* Fix up

* updates

* revert modeling changes

* update

* update

* fix load safe

* addd liscence

* update

* use pixel_values as required by the model

* skip some tests and refactor

* Add pixtral image processing tests (#33476)

* Image processing tests

* Add processing tests

* woops

* defaults reflect pixtral image processor

* fixup post merge

* images -> pixel values

* oups sorry Mr docbuilder

* isort

* fix

* fix processor tests

* small fixes

* nit

* update

* last nits

* oups this was really breaking!

* nits

* is composition needs to be true

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Arthur
2024-09-14 12:28:39 +02:00
committed by GitHub
parent 7bb1c99800
commit 8bd2b1e8c2
24 changed files with 2707 additions and 2 deletions

View File

@@ -862,6 +862,8 @@
title: Perceiver title: Perceiver
- local: model_doc/pix2struct - local: model_doc/pix2struct
title: Pix2Struct title: Pix2Struct
- local: model_doc/pixtral
title: Pixtral
- local: model_doc/sam - local: model_doc/sam
title: Segment Anything title: Segment Anything
- local: model_doc/siglip - local: model_doc/siglip

View File

@@ -253,6 +253,7 @@ Flax), PyTorch, and/or TensorFlow.
| [Phi3](model_doc/phi3) | ✅ | ❌ | ❌ | | [Phi3](model_doc/phi3) | ✅ | ❌ | ❌ |
| [PhoBERT](model_doc/phobert) | ✅ | ✅ | ✅ | | [PhoBERT](model_doc/phobert) | ✅ | ✅ | ✅ |
| [Pix2Struct](model_doc/pix2struct) | ✅ | ❌ | ❌ | | [Pix2Struct](model_doc/pix2struct) | ✅ | ❌ | ❌ |
| [Pixtral](model_doc/pixtral) | ❌ | ❌ | ❌ |
| [PLBart](model_doc/plbart) | ✅ | ❌ | ❌ | | [PLBart](model_doc/plbart) | ✅ | ❌ | ❌ |
| [PoolFormer](model_doc/poolformer) | ✅ | ❌ | ❌ | | [PoolFormer](model_doc/poolformer) | ✅ | ❌ | ❌ |
| [Pop2Piano](model_doc/pop2piano) | ✅ | ❌ | ❌ | | [Pop2Piano](model_doc/pop2piano) | ✅ | ❌ | ❌ |

View File

@@ -0,0 +1,98 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# Pixtral
## Overview
The Pixtral model was released by the Mistral AI team on [Vllm](https://github.com/vllm-project/vllm/pull/8377), where a version of the code can be found!
Tips:
- Pixtral is a multimodal model, the main contribution is the 2d ROPE on the images, and support for arbitrary image size (the images are not padded together nor are they resized)
- This model follows the `Llava` familiy, meaning image embeddings are placed instead of the `[IMG]` token placeholders.
- The format for one or mulitple prompts is the following:
```
"<s>[INST][IMG]\nWhat are the things I should be cautious about when I visit this place?[/INST]"
```
Then, the processor will replace each `[IMG]` token with a number of `[IMG]` token that depends on the height and the width of the image. Each *row* of the image is separated by a `[IMG_BREAK]` token, and each image is separated by a `[IMG_END]` token.
This model was contributed by [amyeroberts](https://huggingface.co/amyeroberts) and [ArthurZ](https://huggingface.co/ArthurZ)
Here is an example of how to run it:
```python
from transformers import LlavaForConditionalGeneration, AutoProcessor
from PIL import Image
model_id = "hf-internal-testing/pixtral-12b"
model = LlavaForConditionalGeneration.from_pretrained(model_id).to("cuda")
processor = AutoProcessor.from_pretrained(model_id)
IMG_URLS = [
"https://picsum.photos/id/237/400/300",
"https://picsum.photos/id/231/200/300",
"https://picsum.photos/id/27/500/500",
"https://picsum.photos/id/17/150/600",
]
PROMPT = "<s>[INST]Describe the images.\n[IMG][IMG][IMG][IMG][/INST]"
inputs = processor(text=PROMPT, images=IMG_URLS, return_tensors="pt").to("cuda")
generate_ids = model.generate(**inputs, max_new_tokens=500)
ouptut = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
EXPECTED_GENERATION = """
Describe the images.
Sure, let's break down each image description:
1. **Image 1:**
- **Description:** A black dog with a glossy coat is sitting on a wooden floor. The dog has a focused expression and is looking directly at the camera.
- **Details:** The wooden floor has a rustic appearance with visible wood grain patterns. The dog's eyes are a striking color, possibly brown or amber, which contrasts with its black fur.
2. **Image 2:**
- **Description:** A scenic view of a mountainous landscape with a winding road cutting through it. The road is surrounded by lush green vegetation and leads to a distant valley.
- **Details:** The mountains are rugged with steep slopes, and the sky is clear, indicating good weather. The winding road adds a sense of depth and perspective to the image.
3. **Image 3:**
- **Description:** A beach scene with waves crashing against the shore. There are several people in the water and on the beach, enjoying the waves and the sunset.
- **Details:** The waves are powerful, creating a dynamic and lively atmosphere. The sky is painted with hues of orange and pink from the setting sun, adding a warm glow to the scene.
4. **Image 4:**
- **Description:** A garden path leading to a large tree with a bench underneath it. The path is bordered by well-maintained grass and flowers.
- **Details:** The path is made of small stones or gravel, and the tree provides a shaded area with the bench invitingly placed beneath it. The surrounding area is lush and green, suggesting a well-kept garden.
Each image captures a different scene, from a close-up of a dog to expansive natural landscapes, showcasing various elements of nature and human interaction with it.
"""
```
## PixtralVisionConfig
[[autodoc]] PixtralVisionConfig
## PixtralModel
[[autodoc]] PixtralModel
- forward
## PixtralImageProcessor
[[autodoc]] PixtralImageProcessor
- preprocess
## PixtralProcessor
[[autodoc]] PixtralProcessor

View File

@@ -649,6 +649,7 @@ _import_structure = {
"Pix2StructTextConfig", "Pix2StructTextConfig",
"Pix2StructVisionConfig", "Pix2StructVisionConfig",
], ],
"models.pixtral": ["PixtralProcessor", "PixtralVisionConfig"],
"models.plbart": ["PLBartConfig"], "models.plbart": ["PLBartConfig"],
"models.poolformer": ["PoolFormerConfig"], "models.poolformer": ["PoolFormerConfig"],
"models.pop2piano": ["Pop2PianoConfig"], "models.pop2piano": ["Pop2PianoConfig"],
@@ -1199,6 +1200,7 @@ else:
_import_structure["models.owlvit"].extend(["OwlViTFeatureExtractor", "OwlViTImageProcessor"]) _import_structure["models.owlvit"].extend(["OwlViTFeatureExtractor", "OwlViTImageProcessor"])
_import_structure["models.perceiver"].extend(["PerceiverFeatureExtractor", "PerceiverImageProcessor"]) _import_structure["models.perceiver"].extend(["PerceiverFeatureExtractor", "PerceiverImageProcessor"])
_import_structure["models.pix2struct"].extend(["Pix2StructImageProcessor"]) _import_structure["models.pix2struct"].extend(["Pix2StructImageProcessor"])
_import_structure["models.pixtral"].append("PixtralImageProcessor")
_import_structure["models.poolformer"].extend(["PoolFormerFeatureExtractor", "PoolFormerImageProcessor"]) _import_structure["models.poolformer"].extend(["PoolFormerFeatureExtractor", "PoolFormerImageProcessor"])
_import_structure["models.pvt"].extend(["PvtImageProcessor"]) _import_structure["models.pvt"].extend(["PvtImageProcessor"])
_import_structure["models.qwen2_vl"].extend(["Qwen2VLImageProcessor"]) _import_structure["models.qwen2_vl"].extend(["Qwen2VLImageProcessor"])
@@ -1359,7 +1361,6 @@ else:
"AlignVisionModel", "AlignVisionModel",
] ]
) )
_import_structure["models.altclip"].extend( _import_structure["models.altclip"].extend(
[ [
"AltCLIPModel", "AltCLIPModel",
@@ -2977,6 +2978,7 @@ else:
"Pix2StructVisionModel", "Pix2StructVisionModel",
] ]
) )
_import_structure["models.pixtral"].extend(["PixtralModel", "PixtralPreTrainedModel"])
_import_structure["models.plbart"].extend( _import_structure["models.plbart"].extend(
[ [
"PLBartForCausalLM", "PLBartForCausalLM",
@@ -5434,6 +5436,10 @@ if TYPE_CHECKING:
Pix2StructTextConfig, Pix2StructTextConfig,
Pix2StructVisionConfig, Pix2StructVisionConfig,
) )
from .models.pixtral import (
PixtralProcessor,
PixtralVisionConfig,
)
from .models.plbart import PLBartConfig from .models.plbart import PLBartConfig
from .models.poolformer import ( from .models.poolformer import (
PoolFormerConfig, PoolFormerConfig,
@@ -6009,6 +6015,7 @@ if TYPE_CHECKING:
from .models.owlvit import OwlViTFeatureExtractor, OwlViTImageProcessor from .models.owlvit import OwlViTFeatureExtractor, OwlViTImageProcessor
from .models.perceiver import PerceiverFeatureExtractor, PerceiverImageProcessor from .models.perceiver import PerceiverFeatureExtractor, PerceiverImageProcessor
from .models.pix2struct import Pix2StructImageProcessor from .models.pix2struct import Pix2StructImageProcessor
from .models.pixtral import PixtralImageProcessor
from .models.poolformer import ( from .models.poolformer import (
PoolFormerFeatureExtractor, PoolFormerFeatureExtractor,
PoolFormerImageProcessor, PoolFormerImageProcessor,
@@ -7448,6 +7455,10 @@ if TYPE_CHECKING:
Pix2StructTextModel, Pix2StructTextModel,
Pix2StructVisionModel, Pix2StructVisionModel,
) )
from .models.pixtral import (
PixtralModel,
PixtralPreTrainedModel,
)
from .models.plbart import ( from .models.plbart import (
PLBartForCausalLM, PLBartForCausalLM,
PLBartForConditionalGeneration, PLBartForConditionalGeneration,

View File

@@ -187,6 +187,7 @@ from . import (
phi3, phi3,
phobert, phobert,
pix2struct, pix2struct,
pixtral,
plbart, plbart,
poolformer, poolformer,
pop2piano, pop2piano,

View File

@@ -205,6 +205,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("phi", "PhiConfig"), ("phi", "PhiConfig"),
("phi3", "Phi3Config"), ("phi3", "Phi3Config"),
("pix2struct", "Pix2StructConfig"), ("pix2struct", "Pix2StructConfig"),
("pixtral", "PixtralVisionConfig"),
("plbart", "PLBartConfig"), ("plbart", "PLBartConfig"),
("poolformer", "PoolFormerConfig"), ("poolformer", "PoolFormerConfig"),
("pop2piano", "Pop2PianoConfig"), ("pop2piano", "Pop2PianoConfig"),
@@ -509,6 +510,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("phi3", "Phi3"), ("phi3", "Phi3"),
("phobert", "PhoBERT"), ("phobert", "PhoBERT"),
("pix2struct", "Pix2Struct"), ("pix2struct", "Pix2Struct"),
("pixtral", "Pixtral"),
("plbart", "PLBart"), ("plbart", "PLBart"),
("poolformer", "PoolFormer"), ("poolformer", "PoolFormer"),
("pop2piano", "Pop2Piano"), ("pop2piano", "Pop2Piano"),

View File

@@ -114,6 +114,7 @@ else:
("owlvit", ("OwlViTImageProcessor",)), ("owlvit", ("OwlViTImageProcessor",)),
("perceiver", ("PerceiverImageProcessor",)), ("perceiver", ("PerceiverImageProcessor",)),
("pix2struct", ("Pix2StructImageProcessor",)), ("pix2struct", ("Pix2StructImageProcessor",)),
("pixtral", ("PixtralImageProcessor",)),
("poolformer", ("PoolFormerImageProcessor",)), ("poolformer", ("PoolFormerImageProcessor",)),
("pvt", ("PvtImageProcessor",)), ("pvt", ("PvtImageProcessor",)),
("pvt_v2", ("PvtImageProcessor",)), ("pvt_v2", ("PvtImageProcessor",)),

View File

@@ -193,6 +193,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("persimmon", "PersimmonModel"), ("persimmon", "PersimmonModel"),
("phi", "PhiModel"), ("phi", "PhiModel"),
("phi3", "Phi3Model"), ("phi3", "Phi3Model"),
("pixtral", "PixtralModel"),
("plbart", "PLBartModel"), ("plbart", "PLBartModel"),
("poolformer", "PoolFormerModel"), ("poolformer", "PoolFormerModel"),
("prophetnet", "ProphetNetModel"), ("prophetnet", "ProphetNetModel"),

View File

@@ -82,6 +82,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
("owlvit", "OwlViTProcessor"), ("owlvit", "OwlViTProcessor"),
("paligemma", "PaliGemmaProcessor"), ("paligemma", "PaliGemmaProcessor"),
("pix2struct", "Pix2StructProcessor"), ("pix2struct", "Pix2StructProcessor"),
("pixtral", "PixtralProcessor"),
("pop2piano", "Pop2PianoProcessor"), ("pop2piano", "Pop2PianoProcessor"),
("qwen2_audio", "Qwen2AudioProcessor"), ("qwen2_audio", "Qwen2AudioProcessor"),
("qwen2_vl", "Qwen2VLProcessor"), ("qwen2_vl", "Qwen2VLProcessor"),

View File

@@ -385,6 +385,7 @@ else:
("phi3", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), ("phi3", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("phobert", ("PhobertTokenizer", None)), ("phobert", ("PhobertTokenizer", None)),
("pix2struct", ("T5Tokenizer", "T5TokenizerFast" if is_tokenizers_available() else None)), ("pix2struct", ("T5Tokenizer", "T5TokenizerFast" if is_tokenizers_available() else None)),
("pixtral", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
("plbart", ("PLBartTokenizer" if is_sentencepiece_available() else None, None)), ("plbart", ("PLBartTokenizer" if is_sentencepiece_available() else None, None)),
("prophetnet", ("ProphetNetTokenizer", None)), ("prophetnet", ("ProphetNetTokenizer", None)),
("qdqbert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), ("qdqbert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),

View File

@@ -73,7 +73,7 @@ class LlavaConfig(PretrainedConfig):
```""" ```"""
model_type = "llava" model_type = "llava"
is_composition = False is_composition = True
def __init__( def __init__(
self, self,

View File

@@ -0,0 +1,70 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
_import_structure = {
"configuration_pixtral": ["PixtralVisionConfig"],
"processing_pixtral": ["PixtralProcessor"],
}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_pixtral"] = [
"PixtralModel",
"PixtralPreTrainedModel",
]
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["image_processing_pixtral"] = ["PixtralImageProcessor"]
if TYPE_CHECKING:
from .configuration_pixtral import PixtralProcessor, PixtralVisionConfig
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_pixtral import (
PixtralModel,
PixtralPreTrainedModel,
)
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .image_processing_pixtral import PixtralImageProcessor
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)

View File

@@ -0,0 +1,103 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pixtral model configuration"""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
class PixtralVisionConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`PixtralModel`]. It is used to instantiate an
Pixtral model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the Pixtral-9B.
e.g. [pixtral-hf/pixtral-9b](https://huggingface.co/pixtral-hf/pixtral-9b)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
hidden_size (`int`, *optional*, defaults to 1024):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 4096):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 24):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads in the Transformer encoder.
num_channels (`int`, *optional*, defaults to 3):
Number of input channels in the input images.
image_size (`int`, *optional*, defaults to 1024):
Max dimension of the input images.
patch_size (`int`, *optional*, defaults to 16):
Size of the image patches.
hidden_act (`str`, *optional*, defaults to `"gelu"`):
Activation function used in the hidden layers.
attention_dropout (`float`, *optional*, defaults to 0.0):
Dropout probability for the attention layers.
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether to tie the word embeddings with the input embeddings.
Example:
```python
>>> from transformers import PixtralModel, PixtralVisionConfig, CLIPVisionConfig, LlamaConfig
>>> # Initializing a Pixtral 12B style configuration
>>> config = PixtralVisionConfig()
>>> # Initializing a model from the pixtral 12B style configuration
>>> model = PixtralModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "pixtral"
def __init__(
self,
hidden_size=1024,
intermediate_size=4096,
num_hidden_layers=24,
num_attention_heads=16,
num_channels=3,
image_size=1024,
patch_size=16,
hidden_act="gelu",
attention_dropout=0.0,
rope_theta=10000.0,
tie_word_embeddings=False,
**kwargs,
):
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_channels = num_channels
self.patch_size = patch_size
self.image_size = image_size
self.attention_dropout = attention_dropout
self.hidden_act = hidden_act
self.rope_theta = rope_theta
self.tie_word_embeddings = tie_word_embeddings
self.head_dim = hidden_size // num_attention_heads

View File

@@ -0,0 +1,285 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import regex as re
import torch
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from safetensors.torch import load_file as safe_load_file
from tokenizers import Regex, Tokenizer, decoders, pre_tokenizers, processors
from tokenizers.models import BPE
from transformers import (
LlavaConfig,
LlavaForConditionalGeneration,
MistralConfig,
PixtralImageProcessor,
PixtralProcessor,
PixtralVisionConfig,
PreTrainedTokenizerFast,
)
from transformers.convert_slow_tokenizer import bytes_to_unicode
"""
# Here is how to get the original tokens!
model_name = "mistralai/Pixtral-12B-2409"
tok = MistralTokenizer.from_model(model_name)
from mistral_common.protocol.instruct.request import ChatCompletionRequest, UserMessage, ImageChunk, TextChunk
EXPECTED_TOKENS = tok.encode_chat_completion(
ChatCompletionRequest(
messages=[
UserMessage(
content=[
TextChunk(text="Describe the images"),
] + [ImageChunk(image=img) for img in IMG_URLS]
)
],
model="pixtral",
)
)
assert tokenizer.decode(inputs["input_ids"][0]) == EXPECTED_TOKENS
"""
OLD_KEY_TO_NEW_KEY_MAPPING = {
# Layer Normalization Weights
r"vision_encoder.transformer.layers.(\d+).input_layernorm.weight": r"vision_tower.transformer.layers.\1.attention_norm.weight",
r"vision_encoder.transformer.layers.(\d+).ffn_norm.weight": r"vision_tower.transformer.layers.\1.ffn_norm.weight",
# Self Attention Projections
r"vision_encoder.transformer.layers.(\d+).attention.wq.weight": r"vision_tower.transformer.layers.\1.attention.q_proj.weight",
r"vision_encoder.transformer.layers.(\d+).attention.wk.weight": r"vision_tower.transformer.layers.\1.attention.k_proj.weight",
r"vision_encoder.transformer.layers.(\d+).attention.wv.weight": r"vision_tower.transformer.layers.\1.attention.v_proj.weight",
r"vision_encoder.transformer.layers.(\d+).attention.wo.weight": r"vision_tower.transformer.layers.\1.attention.o_proj.weight",
# MLP Projections
r"vision_encoder.transformer.layers.(\d+).feed_forward.w1.weight": r"vision_tower.transformer.layers.\1.feed_forward.gate_proj.weight",
r"vision_encoder.transformer.layers.(\d+).feed_forward.w2.weight": r"vision_tower.transformer.layers.\1.feed_forward.down_proj.weight",
r"vision_encoder.transformer.layers.(\d+).feed_forward.w3.weight": r"vision_tower.transformer.layers.\1.feed_forward.up_proj.weight",
# Additional mappings
r"vision_encoder": r"vision_tower",
r"vision_language_adapter.w_in": r"multi_modal_projector.linear_1",
r"vision_language_adapter.w_out": r"multi_modal_projector.linear_2",
r"layers.(\d+).attention.wq.weight": r"language_model.model.layers.\1.self_attn.q_proj.weight",
r"layers.(\d+).attention.wk.weight": r"language_model.model.layers.\1.self_attn.k_proj.weight",
r"layers.(\d+).attention.wv.weight": r"language_model.model.layers.\1.self_attn.v_proj.weight",
r"layers.(\d+).attention.wo.weight": r"language_model.model.layers.\1.self_attn.o_proj.weight",
r"layers.(\d+).feed_forward.w1.weight": r"language_model.model.layers.\1.mlp.gate_proj.weight",
r"layers.(\d+).feed_forward.w2.weight": r"language_model.model.layers.\1.mlp.down_proj.weight",
r"layers.(\d+).feed_forward.w3.weight": r"language_model.model.layers.\1.mlp.up_proj.weight",
r"layers.(\d+).ffn_norm.weight": r"language_model.model.layers.\1.post_attention_layernorm.weight",
r"layers.(\d+).attention_norm.weight": r"language_model.model.layers.\1.input_layernorm.weight",
r"tok_embeddings.weight": r"language_model.model.embed_tokens.weight",
r"output.weight": r"language_model.lm_head.weight",
r"norm.weight": r"language_model.model.norm.weight",
}
class MistralConverter:
"""
A general tiktoken converter.
"""
def __init__(
self,
vocab=None,
pattern=r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""",
add_prefix_space=False,
additional_special_tokens=None,
*args,
**kwargs,
):
super().__init__(*args)
self.vocab = vocab
self.pattern = pattern
self.add_prefix_space = add_prefix_space
self.additional_special_tokens = additional_special_tokens
def extract_vocab_merges_from_model(self, vocab: str):
bpe_ranks = vocab
byte_encoder = bytes_to_unicode()
def token_bytes_to_string(b):
return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")])
merges = []
vocab = {}
for idx, (token, rank) in enumerate(bpe_ranks.items()):
if token not in self.additional_special_tokens:
vocab[token_bytes_to_string(token)] = idx
if len(token) == 1:
continue
local = []
for index in range(1, len(token)):
piece_l, piece_r = token[:index], token[index:]
if piece_l in bpe_ranks and piece_r in bpe_ranks and (piece_l + piece_r) in bpe_ranks:
local.append((piece_l, piece_r, rank))
local = sorted(local, key=lambda x: (bpe_ranks[x[0]], bpe_ranks[x[1]]), reverse=False)
merges.extend(local)
else:
vocab[token] = idx
merges = sorted(merges, key=lambda val: val[2], reverse=False)
merges = [(token_bytes_to_string(val[0]), token_bytes_to_string(val[1])) for val in merges]
return vocab, merges
def tokenizer(self):
vocab_scores, merges = self.extract_vocab_merges_from_model(self.vocab)
tokenizer = Tokenizer(BPE(vocab_scores, merges, fuse_unk=False))
if hasattr(tokenizer.model, "ignore_merges"):
tokenizer.model.ignore_merges = True
return tokenizer
def converted(self) -> Tokenizer:
tokenizer = self.tokenizer()
tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
[
pre_tokenizers.Split(Regex(self.pattern), behavior="isolated", invert=False),
pre_tokenizers.ByteLevel(add_prefix_space=self.add_prefix_space, use_regex=False),
]
)
tokenizer.decoder = decoders.ByteLevel()
tokenizer.add_special_tokens(self.additional_special_tokens)
tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
return tokenizer
def convert_mistral_tokenizer():
model_name = "mistralai/Pixtral-12B-2409"
tokenizer = MistralTokenizer.from_model(model_name)
vocab = tokenizer.instruct_tokenizer.tokenizer._tekken_token2id_nospecial
all_special = [
token.value if hasattr(token, "value") else token
for token in tokenizer.instruct_tokenizer.tokenizer._all_special_tokens
]
specials_tokens = {token: all_special.index(token) for token in all_special}
specials_tokens.update(vocab)
vocab = specials_tokens
tokenizer = PreTrainedTokenizerFast(
tokenizer_object=MistralConverter(vocab=vocab, additional_special_tokens=all_special).converted(),
bos_token="<s>",
unk_token="<unk>",
eos_token="</s>",
)
tokenizer.model_input_names = ["input_ids", "attention_mask"]
return tokenizer
def permute_for_rope(value, n_heads, config):
dim1 = value.shape[0]
dim2 = config.hidden_size
return value.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
def convert_dictionnary(original_state_dict, vision_config, text_config):
new_dict = {}
all_keys = "\n" + "\n".join(original_state_dict.keys())
old_keys = all_keys
for old, new in OLD_KEY_TO_NEW_KEY_MAPPING.items():
all_keys = re.sub(r"\n" + old, r"\n" + new, all_keys)
OLD_TO_NEW = dict(zip(old_keys.split("\n"), all_keys.split("\n")))
for key, value in original_state_dict.items():
new_key = OLD_TO_NEW[key]
if "vision_encoder" in key:
_config = vision_config
num_attention_heads = _config.num_attention_heads
else:
_config = text_config
if "q_proj" in new_key:
num_attention_heads = _config.num_attention_heads
if "k_proj" in new_key:
num_attention_heads = _config.num_key_value_heads
# convert the text model (basically mistral model)
if "q_proj" in new_key or "k_proj" in new_key:
value = permute_for_rope(value, num_attention_heads, _config)
new_dict[new_key] = value
return new_dict
def convert_mistral_model(input_dir, output_dir):
text_config = MistralConfig(
attention_dropout=0.0,
bos_token_id=1,
eos_token_id=2,
head_dim=128,
hidden_act="silu",
hidden_size=5120,
initializer_range=0.02,
intermediate_size=14336,
max_position_embeddings=1024000,
model_type="mistral",
num_attention_heads=32,
num_hidden_layers=40,
num_key_value_heads=8,
rms_norm_eps=1e-05,
rope_theta=1000000000.0,
sliding_window=None,
tie_word_embeddings=False,
vocab_size=131072,
)
vision_config = PixtralVisionConfig()
config = LlavaConfig(
vision_config,
text_config,
vision_feature_layer=-1,
image_token_index=10,
vision_feature_select_strategy="full",
image_seq_length=1,
)
config.architectures = ["LlavaForConditionalGeneration"]
config.save_pretrained(output_dir)
original_state_dict = safe_load_file(f"{input_dir}/consolidated.safetensors")
new_dict = convert_dictionnary(original_state_dict, vision_config, text_config)
with torch.device("meta"):
model = LlavaForConditionalGeneration(config)
model.load_state_dict(new_dict, strict=True, assign=True)
model.save_pretrained(output_dir)
tokenizer = convert_mistral_tokenizer()
image_processor = PixtralImageProcessor()
processor = PixtralProcessor(tokenizer=tokenizer, image_processor=image_processor, image_token="[IMG]")
processor.save_pretrained(output_dir)
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--input_dir",
help="Location of LLaMA weights, which contains tokenizer.model and model folders",
)
parser.add_argument(
"--output_dir",
help="Location to write HF model and tokenizer",
)
args = parser.parse_args()
convert_mistral_model(args.input_dir, args.output_dir)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,519 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for Pixtral."""
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import (
resize,
to_channel_dimension_format,
)
from ...image_utils import (
ChannelDimension,
ImageInput,
PILImageResampling,
get_image_size,
infer_channel_dimension_format,
is_scaled_image,
is_valid_image,
to_numpy_array,
valid_images,
validate_kwargs,
validate_preprocess_arguments,
)
from ...utils import TensorType, is_torch_device, is_torch_dtype, is_torch_tensor, is_vision_available, logging
from ...utils.import_utils import requires_backends
logger = logging.get_logger(__name__)
if is_vision_available():
import PIL
class BatchMixFeature(BatchFeature):
def to(self, *args, **kwargs) -> "BatchMixFeature":
"""
Send all values to device by calling `v.to(*args, **kwargs)` (PyTorch only). This should support casting in
different `dtypes` and sending the `BatchFeature` to a different `device`.
Args:
args (`Tuple`):
Will be passed to the `to(...)` function of the tensors.
kwargs (`Dict`, *optional*):
Will be passed to the `to(...)` function of the tensors.
Returns:
[`BatchFeature`]: The same instance after modification.
"""
requires_backends(self, ["torch"])
import torch # noqa
new_data = {}
device = kwargs.get("device")
# Check if the args are a device or a dtype
if device is None and len(args) > 0:
# device should be always the first argument
arg = args[0]
if is_torch_dtype(arg):
# The first argument is a dtype
pass
elif isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int):
device = arg
else:
# it's something else
raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")
# We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
for k, v in self.items():
# check if v is a floating point
if isinstance(v, list):
new_data[k] = [
element.to(*args, **kwargs) for sample in v for element in sample if is_torch_tensor(element)
]
elif torch.is_floating_point(v):
# cast and send to device
new_data[k] = v.to(*args, **kwargs)
elif device is not None:
new_data[k] = v.to(device=device)
else:
new_data[k] = v
self.data = new_data
return self
# Copied from transformers.models.idefics2.image_processing_idefics2.make_list_of_images
def make_list_of_images(images: ImageInput) -> List[List[np.ndarray]]:
"""
Convert a single image or a list of images to a list of numpy arrays.
Args:
images (`ImageInput`):
A single image or a list of images.
Returns:
A list of numpy arrays.
"""
# If it's a single image, convert it to a list of lists
if is_valid_image(images):
images = [[images]]
# If it's a list of images, it's a single batch, so convert it to a list of lists
elif isinstance(images, (list, tuple)) and len(images) > 0 and is_valid_image(images[0]):
images = [images]
# If it's a list of batches, it's already in the right format
elif (
isinstance(images, (list, tuple))
and len(images) > 0
and isinstance(images[0], (list, tuple))
and is_valid_image(images[0][0])
):
pass
else:
raise ValueError(
"Invalid input type. Must be a single image, a list of images, or a list of batches of images."
)
return images
# Adapted from function in image_transforms.py to ensure any transparent pixels are converted to white.
def convert_to_rgb(image: ImageInput) -> ImageInput:
"""
Converts an image to RGB format. Only converts if the image is of type PIL.Image.Image, otherwise returns the image
as is.
Args:
image (Image):
The image to convert.
"""
requires_backends(convert_to_rgb, ["vision"])
if not isinstance(image, PIL.Image.Image):
return image
if image.mode == "RGB":
return image
# First we convert to RGBA to set background to white.
image = image.convert("RGBA")
# Create a new image with a white background.
new_image = PIL.Image.new("RGBA", image.size, "WHITE")
new_image.paste(image, (0, 0), image)
new_image = new_image.convert("RGB")
return new_image
def _num_image_tokens(image_size: Tuple[int, int], patch_size: Tuple[int, int]) -> int:
"""
Calculate the number of image tokens given the image size and patch size.
Args:
image_size (`Tuple[int, int]`):
The size of the image as `(height, width)`.
patch_size (`Tuple[int, int]`):
The patch size as `(height, width)`.
Returns:
`int`: The number of image tokens.
"""
height, width = image_size
patch_height, patch_width = patch_size if isinstance(patch_size, (tuple, list)) else (patch_size, patch_size)
num_width_tokens = (width - 1) // patch_width + 1
num_height_tokens = (height - 1) // patch_height + 1
return num_height_tokens, num_width_tokens
def get_resize_output_image_size(
input_image: np.ndarray,
size: Union[int, Tuple[int, int], List[int], Tuple[int]],
patch_size: Union[int, Tuple[int, int], List[int], Tuple[int]],
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> tuple:
"""
Find the target (height, width) dimension of the output image after resizing given the input image and the desired
size.
Args:
input_image (`np.ndarray`):
The image to resize.
size (`int` or `Tuple[int, int]`):
Max image size an input image can be. Must be a dictionary with the key "longest_edge".
patch_size (`int` or `Tuple[int, int]`):
The patch_size as `(height, width)` to use for resizing the image. If patch_size is an integer, `(patch_size, patch_size)`
will be used
input_data_format (`ChannelDimension`, *optional*):
The channel dimension format of the input image. If unset, will use the inferred format from the input.
Returns:
`tuple`: The target (height, width) dimension of the output image after resizing.
"""
max_height, max_width = size if isinstance(size, (tuple, list)) else (size, size)
patch_height, patch_width = patch_size if isinstance(patch_size, (tuple, list)) else (patch_size, patch_size)
height, width = get_image_size(input_image, input_data_format)
ratio = max(height / max_height, width / max_width)
if ratio > 1:
# Orgiginal implementation uses `round` which utilises bankers rounding, which can lead to surprising results
height = int(np.ceil(height / ratio))
width = int(np.ceil(width / ratio))
num_height_tokens, num_width_tokens = _num_image_tokens((height, width), (patch_height, patch_width))
return num_height_tokens * patch_height, num_width_tokens * patch_width
# Hack to get tensor conversion used in BatchFeature without batching the images
def _get_is_as_tensor_fns(tensor_type: Union[str, TensorType]) -> Tuple[Callable, Callable]:
return BatchFeature()._get_is_as_tensor_fns(tensor_type)
def convert_to_tensor(array, tensor_type: Union[str, TensorType]) -> Any:
is_tensor, as_tensor = _get_is_as_tensor_fns(tensor_type)
if is_tensor(array):
return array
return as_tensor(array)
class PixtralImageProcessor(BaseImageProcessor):
r"""
Constructs a Pixtral image processor.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
`do_resize` in the `preprocess` method.
size (`Dict[str, int]` *optional*, defaults to `{"longest_edge": 1024}`):
Size of the maximum dimension of either the height or width dimension of the image. Used to control how
images are resized. If either the height or width are greater than `size["longest_edge"]` then both the height and width are rescaled by `height / ratio`, `width /ratio` where `ratio = max(height / longest_edge, width / longest_edge)`
patch_size (`Dict[str, int]` *optional*, defaults to `{"height": 16, "width": 16}`):
Size of the patches in the model, used to calculate the output image size. Can be overridden by `patch_size` in the `preprocess` method.
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
do_rescale (`bool`, *optional*, defaults to `True`):
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
the `preprocess` method.
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
method.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method.
image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
Can be overridden by the `image_std` parameter in the `preprocess` method.
do_convert_rgb (`bool`, *optional*, defaults to `True`):
Whether to convert the image to RGB.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize: bool = True,
size: Dict[str, int] = None,
patch_size: Dict[str, int] = None,
resample: PILImageResampling = PILImageResampling.BICUBIC,
do_rescale: bool = True,
rescale_factor: Union[int, float] = 1 / 255,
do_normalize: bool = True,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_convert_rgb: bool = True,
**kwargs,
) -> None:
super().__init__(**kwargs)
size = size if size is not None else {"longest_edge": 1024}
patch_size = patch_size if patch_size is not None else {"height": 16, "width": 16}
patch_size = get_size_dict(patch_size, default_to_square=True)
self.do_resize = do_resize
self.size = size
self.patch_size = patch_size
self.resample = resample
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else [0.48145466, 0.4578275, 0.40821073]
self.image_std = image_std if image_std is not None else [0.26862954, 0.26130258, 0.27577711]
self.do_convert_rgb = do_convert_rgb
self._valid_processor_keys = [
"images",
"do_resize",
"size",
"patch_size",
"resample",
"do_rescale",
"rescale_factor",
"do_normalize",
"image_mean",
"image_std",
"do_convert_rgb",
"return_tensors",
"data_format",
"input_data_format",
]
def resize(
self,
image: np.ndarray,
size: Dict[str, int],
patch_size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BICUBIC,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge
resized to keep the input aspect ratio.
Args:
image (`np.ndarray`):
Image to resize.
size (`Dict[str, int]`):
Dict containing the longest possible edge of the image.
patch_size (`Dict[str, int]`):
Patch size used to calculate the size of the output image.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
Resampling filter to use when resiizing the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
"""
if "longest_edge" in size:
size = (size["longest_edge"], size["longest_edge"])
elif "height" in size and "width" in size:
size = (size["height"], size["width"])
else:
raise ValueError("size must contain either 'longest_edge' or 'height' and 'width'.")
if "height" in patch_size and "width" in patch_size:
patch_size = (patch_size["height"], patch_size["width"])
else:
raise ValueError("patch_size must contain either 'shortest_edge' or 'height' and 'width'.")
output_size = get_resize_output_image_size(
image,
size=size,
patch_size=patch_size,
input_data_format=input_data_format,
)
return resize(
image,
size=output_size,
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)
def preprocess(
self,
images: ImageInput,
do_resize: bool = None,
size: Dict[str, int] = None,
patch_size: Dict[str, int] = None,
resample: PILImageResampling = None,
do_rescale: bool = None,
rescale_factor: float = None,
do_normalize: bool = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_convert_rgb: bool = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> PIL.Image.Image:
"""
Preprocess an image or batch of images.
Args:
images (`ImageInput`):
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
Describes the maximum input dimensions to the model.
patch_size (`Dict[str, int]`, *optional*, defaults to `self.patch_size`):
Patch size in the model. Used to calculate the image after resizing.
resample (`int`, *optional*, defaults to `self.resample`):
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
has an effect if `do_resize` is set to `True`.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
Whether to rescale the image.
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
Whether to normalize the image.
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
`True`.
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
Whether to convert the image to RGB.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: Use the channel dimension format of the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
"""
patch_size = patch_size if patch_size is not None else self.patch_size
patch_size = get_size_dict(patch_size, default_to_square=True)
do_resize = do_resize if do_resize is not None else self.do_resize
size = size if size is not None else self.size
resample = resample if resample is not None else self.resample
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)
images_list = make_list_of_images(images)
if not valid_images(images_list[0]):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_resize=do_resize,
size=size,
resample=resample,
)
if do_convert_rgb:
images_list = [[convert_to_rgb(image) for image in images] for images in images_list]
# All transformations expect numpy arrays.
images_list = [[to_numpy_array(image) for image in images] for images in images_list]
if is_scaled_image(images_list[0][0]) and do_rescale:
logger.warning_once(
"It looks like you are trying to rescale already rescaled images. If the input"
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
)
if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images_list[0][0])
batch_images = []
batch_image_sizes = []
for sample_images in images_list:
images = []
image_sizes = []
for image in sample_images:
if do_resize:
image = self.resize(
image=image,
size=size,
patch_size=patch_size,
resample=resample,
input_data_format=input_data_format,
)
if do_rescale:
image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
if do_normalize:
image = self.normalize(
image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
)
images.append(image)
image_sizes.append(get_image_size(image, input_data_format))
batch_images.append(images)
batch_image_sizes.append(image_sizes)
images_list = [
[to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images]
for images in batch_images
]
# Convert to tensor type outside of BatchFeature to avoid batching the images of different sizes
images_list = [[convert_to_tensor(image, return_tensors) for image in images] for images in images_list]
return BatchMixFeature(data={"pixel_values": images_list, "image_sizes": batch_image_sizes}, tensor_type=None)

View File

@@ -0,0 +1,517 @@
# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Pixtral model."""
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from ... import PreTrainedModel
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
)
from .configuration_pixtral import PixtralVisionConfig
logger = logging.get_logger(__name__)
def position_ids_in_meshgrid(patch_embeds_list, max_width):
positions = []
for patch in patch_embeds_list:
height, width = patch.shape[-2:]
mesh = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij")
h_grid, v_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1)
ids = h_grid * max_width + v_grid
positions.append(ids[:, 0])
return torch.cat(positions)
class PixtralRotaryEmbedding(nn.Module):
"""
The key with pixtral embedding is just that you have a frequency for each pixel positions.
If you have height x width pixels (or embedding pixels)
then the frequency used for ROPE is given by indexing the pre_computed frequency on the
width and height.
What you output is of dimension batch, height * width, dim with dim the embed dim.
This simply means that for each image hidden states, you are going to add
a corresponding positional embedding, based on it's index in the grid.
"""
def __init__(self, config, device):
super().__init__()
self.rope_type = "default"
self.dim = config.head_dim
self.base = config.rope_theta
max_patches_per_side = config.image_size // config.patch_size
freqs = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
h = torch.arange(max_patches_per_side, device=freqs.device)
w = torch.arange(max_patches_per_side, device=freqs.device)
freqs_h = torch.outer(h, freqs[::2]).float()
freqs_w = torch.outer(w, freqs[1::2]).float()
inv_freq = torch.cat(
[
freqs_h[:, None, :].repeat(1, max_patches_per_side, 1),
freqs_w[None, :, :].repeat(max_patches_per_side, 1, 1),
],
dim=-1,
).reshape(-1, self.dim // 2) # we reshape to only index on the position indexes, not tuple of indexes
# Different from paper, but it uses a different permutation in order to obtain the same calculation
# TODO maybe make it torch compatible later on. We can also just slice
self.register_buffer("inv_freq", torch.cat((inv_freq, inv_freq), dim=-1), persistent=False)
@torch.no_grad()
def forward(self, x, position_ids):
if "dynamic" in self.rope_type:
self._dynamic_frequency_update(position_ids, device=x.device)
# Core RoPE block
freqs = self.inv_freq[position_ids]
# position_ids_expanded = position_ids[:, None, :].float()
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
emb = freqs
cos = emb.cos()
sin = emb.sin()
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
def _dynamic_frequency_update(self, position_ids, device):
"""
dynamic RoPE layers should recompute `inv_freq` in the following situations:
1 - growing beyond the cached sequence length (allow scaling)
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
"""
seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn(
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class PixtralAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
self.o_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_embeddings: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Input shape: Batch x Time x Channel"""
batch_size, patches, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=0)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(batch_size, patches, -1)
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Pixtral
class PixtralMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, hidden_state):
return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Pixtral
class PixtralRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
PixtralRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class PixtralAttentionLayer(nn.Module):
def __init__(self, config):
super().__init__()
self.attention_norm = PixtralRMSNorm(config.hidden_size, eps=1e-5)
self.feed_forward = PixtralMLP(config)
self.attention = PixtralAttention(config)
self.ffn_norm = PixtralRMSNorm(config.hidden_size, eps=1e-5)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
position_embeddings: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.FloatTensor]:
"""
Args:
hidden_states (`torch.FloatTensor`):
Input to the layer of shape `(batch, seq_len, embed_dim)`.
attention_mask (`torch.FloatTensor`):
Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
residual = hidden_states
hidden_states = self.attention_norm(hidden_states)
hidden_states, attn_weights = self.attention(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_embeddings=position_embeddings,
output_attentions=output_attentions,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.ffn_norm(hidden_states)
hidden_states = self.feed_forward(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
class PixtralTransformer(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.layers = torch.nn.ModuleList()
for _ in range(config.num_hidden_layers):
self.layers.append(PixtralAttentionLayer(config))
self.gradient_checkpointing = False
def forward(
self,
inputs_embeds,
attention_mask: Optional[torch.Tensor] = None,
position_embeddings: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutput]:
r"""
Args:
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
hidden_states = inputs_embeds
for encoder_layer in self.layers:
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
position_embeddings,
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
position_embeddings=position_embeddings,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=[hidden_states], attentions=all_attentions
)
PIXTRAL_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`PixtralVisionConfig`] or [`PixtralVisionConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
PIXTRAL_START_DOCSTRING,
)
class PixtralPreTrainedModel(PreTrainedModel):
config_class = PixtralVisionConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["PixtralVisionAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_cache_class = True
def _init_weights(self, module):
# important: this ported version of Pixtral isn't meant for training from scratch - only
# inference and fine-tuning - so the proper init weights code has been removed - the original codebase
# https://github.com/haotian-liu/LLaVA/tree/main/pixtral should serve for that purpose
std = (
self.config.initializer_range
if hasattr(self.config, "initializer_range")
else self.config.text_config.initializer_range
)
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
PIXTRAL_INPUTS_DOCSTRING = r"""
Args:
pixel_values: list of N_img images of variable sizes,
each of shape (C, H, W)
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
def generate_block_attention_mask(patch_embeds_list, tensor):
dtype = tensor.dtype
device = tensor.device
seq_len = tensor.shape[1]
d_min = torch.finfo(dtype).min
causal_mask = torch.full((seq_len, seq_len), fill_value=d_min, dtype=dtype, device=device)
block_end_idx = torch.tensor(patch_embeds_list).cumsum(-1)
block_start_idx = torch.tensor([0] + patch_embeds_list[:-1]).cumsum(-1)
for start, end in zip(block_start_idx, block_end_idx):
causal_mask[start:end, start:end] = 0
causal_mask = causal_mask[None, None, :, :].expand(tensor.shape[0], 1, -1, -1)
return causal_mask
@add_start_docstrings(
"""The PIXTRAL model which consists of a vision backbone and a language model.""",
PIXTRAL_START_DOCSTRING,
)
class PixtralModel(PixtralPreTrainedModel):
base_model_prefix = "vision_encoder"
def __init__(self, config):
super().__init__(config)
self.config = config
self.patch_conv = nn.Conv2d(
in_channels=config.num_channels,
out_channels=config.hidden_size,
kernel_size=config.patch_size,
stride=config.patch_size,
bias=False,
)
self.ln_pre = PixtralRMSNorm(config.hidden_size, eps=1e-5)
self.transformer = PixtralTransformer(config)
self.patch_positional_embedding = PixtralRotaryEmbedding(config, device=self.device)
@add_start_docstrings_to_model_forward(PIXTRAL_INPUTS_DOCSTRING)
def forward(
self,
pixel_values: List[torch.Tensor],
output_hidden_states: Optional[bool] = False,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
*args,
**kwargs,
) -> Union[Tuple, BaseModelOutput]:
"""
Returns:
pixel_values: tensor of token features for
all tokens of all images of shape (N_toks, D)
"""
# pass images through initial convolution independently
patch_embeds_list = [self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in pixel_values]
# flatten to a single sequence
patch_embeds = torch.cat([p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1)
patch_embeds = self.ln_pre(patch_embeds)
# positional embeddings
position_ids = position_ids_in_meshgrid(
patch_embeds_list, max_width=self.config.image_size // self.config.patch_size
).to(self.device)
position_embedding = self.patch_positional_embedding(patch_embeds, position_ids)
attention_mask = generate_block_attention_mask(
[p.shape[-2] * p.shape[-1] for p in patch_embeds_list], patch_embeds
)
return self.transformer(patch_embeds, attention_mask, position_embedding)

View File

@@ -0,0 +1,282 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Processor class for Pixtral.
"""
from typing import List, Optional, Union
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, is_valid_image, load_image
from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
from ...utils import TensorType, is_torch_device, is_torch_dtype, is_torch_tensor, logging, requires_backends
logger = logging.get_logger(__name__)
# Copied from transformers.models.idefics2.processing_idefics2.is_url
def is_url(val) -> bool:
return isinstance(val, str) and val.startswith("http")
# Copied from transformers.models.idefics2.processing_idefics2.is_image_or_image_url
def is_image_or_image_url(elem):
return is_url(elem) or is_valid_image(elem)
# Copied from transformers.models.pixtral.image_processing_pixtral.BatchMixFeature
class BatchMixFeature(BatchFeature):
def to(self, *args, **kwargs) -> "BatchMixFeature":
"""
Send all values to device by calling `v.to(*args, **kwargs)` (PyTorch only). This should support casting in
different `dtypes` and sending the `BatchFeature` to a different `device`.
Args:
args (`Tuple`):
Will be passed to the `to(...)` function of the tensors.
kwargs (`Dict`, *optional*):
Will be passed to the `to(...)` function of the tensors.
Returns:
[`BatchFeature`]: The same instance after modification.
"""
requires_backends(self, ["torch"])
import torch # noqa
new_data = {}
device = kwargs.get("device")
# Check if the args are a device or a dtype
if device is None and len(args) > 0:
# device should be always the first argument
arg = args[0]
if is_torch_dtype(arg):
# The first argument is a dtype
pass
elif isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int):
device = arg
else:
# it's something else
raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")
# We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
for k, v in self.items():
# check if v is a floating point
if isinstance(v, list):
new_data[k] = [
element.to(*args, **kwargs) for sample in v for element in sample if is_torch_tensor(element)
]
elif torch.is_floating_point(v):
# cast and send to device
new_data[k] = v.to(*args, **kwargs)
elif device is not None:
new_data[k] = v.to(device=device)
else:
new_data[k] = v
self.data = new_data
return self
class PixtralProcessor(ProcessorMixin):
r"""
Constructs a Pixtral processor which wraps a Pixtral image processor and a Pixtral tokenizer into a single processor.
[`PixtralProcessor`] offers all the functionalities of [`CLIPImageProcessor`] and [`LlamaTokenizerFast`]. See the
[`~PixtralProcessor.__call__`] and [`~PixtralProcessor.decode`] for more information.
Args:
image_processor ([`PixtralImageProcessor`], *optional*):
The image processor is a required input.
tokenizer ([`LlamaTokenizerFast`], *optional*):
The tokenizer is a required input.
patch_size (`int`, *optional*, defaults to 16):
Patch size from the vision tower.
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
in a chat into a tokenizable string.
image_token (`str`, *optional*, defaults to `"[IMG]"`):
Special token used to denote image location.
image_break_token (`str`, *optional*, defaults to `"[IMG_BREAK]"`):
Special token used to denote the end of a line of pixels in an image.
image_end_token (`str`, *optional*, defaults to `"[IMG_END]"`):
Special token used to denote the end of an image input.
"""
attributes = ["image_processor", "tokenizer"]
valid_kwargs = [
"chat_template",
"patch_size",
"image_token",
"image_break_token",
"image_end_token",
]
image_processor_class = "AutoImageProcessor"
tokenizer_class = "AutoTokenizer"
def __init__(
self,
image_processor=None,
tokenizer=None,
patch_size: int = 16,
chat_template=None,
image_token="[IMG]", # set the default and let users change if they have peculiar special tokens in rare cases
image_break_token="[IMG_BREAK]",
image_end_token="[IMG_END]",
**kwargs,
):
self.patch_size = patch_size
self.image_token = image_token
self.image_break_token = image_break_token
self.image_end_token = image_end_token
super().__init__(image_processor, tokenizer, chat_template=chat_template)
def __call__(
self,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
images: ImageInput = None,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = None,
max_length=None,
return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
) -> BatchMixFeature:
"""
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to
CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring
of the above two methods for more information.
Args:
text (`str`, `List[str]`, `List[List[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. Both channels-first and channels-last formats are supported.
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding
index) among:
- `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
sequence if provided).
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
acceptable input length for the model if that argument is not provided.
- `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
lengths).
max_length (`int`, *optional*):
Maximum length of the returned list and optionally padding length (see above).
truncation (`bool`, *optional*):
Activates truncation to cut input sequences longer than `max_length` to `max_length`.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
`None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
"""
if images is not None:
if is_image_or_image_url(images):
images = [[images]]
elif isinstance(images, list) and is_image_or_image_url(images[0]):
images = [images]
elif (
not isinstance(images, list)
and not isinstance(images[0], list)
and not is_image_or_image_url(images[0][0])
):
raise ValueError(
"Invalid input images. Please provide a single image or a list of images or a list of list of images."
)
images = [[load_image(im) for im in sample] for sample in images]
image_inputs = self.image_processor(images, patch_size=self.patch_size, return_tensors=return_tensors)
else:
image_inputs = {}
if isinstance(text, str):
text = [text]
elif not isinstance(text, list) and not isinstance(text[0], str):
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
# try to expand inputs in processing if we have the necessary parts
prompt_strings = text
if image_inputs.get("pixel_values") is not None:
# Replace the image token with the expanded image token sequence
images = image_inputs["pixel_values"]
image_sizes = image_inputs.pop("image_sizes")
prompt_strings = []
for sample_images, sample_image_sizes, sample in zip(images, image_sizes, text):
replace_strings = []
# First calculate the number of tokens needed for each image and put in a placeholder
for image, image_size in zip(sample_images, sample_image_sizes):
height, width = image_size
num_height_tokens = height // self.patch_size
num_width_tokens = width // self.patch_size
replace_tokens = [
[self.image_token] * num_width_tokens + [self.image_break_token]
] * num_height_tokens
# Flatten list
replace_tokens = [item for sublist in replace_tokens for item in sublist]
replace_tokens[-1] = self.image_end_token
replace_str = "".join(replace_tokens)
replace_strings.append(replace_str)
sample = sample.replace(self.image_token, "<placeholder>", 1)
while "<placeholder>" in sample:
replace_str = replace_strings.pop(0)
sample = sample.replace("<placeholder>", replace_str, 1)
prompt_strings.append(sample)
text_inputs = self.tokenizer(
prompt_strings,
return_tensors=return_tensors,
padding=padding,
truncation=truncation,
max_length=max_length,
)
return BatchMixFeature(data={**text_inputs, **image_inputs})
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
refer to the docstring of this method for more information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
def decode(self, *args, **kwargs):
"""
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
the docstring of this method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)
@property
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))

View File

@@ -7067,6 +7067,20 @@ class Pix2StructVisionModel(metaclass=DummyObject):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
class PixtralModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class PixtralPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class PLBartForCausalLM(metaclass=DummyObject): class PLBartForCausalLM(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]

View File

@@ -506,6 +506,13 @@ class Pix2StructImageProcessor(metaclass=DummyObject):
requires_backends(self, ["vision"]) requires_backends(self, ["vision"])
class PixtralImageProcessor(metaclass=DummyObject):
_backends = ["vision"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])
class PoolFormerFeatureExtractor(metaclass=DummyObject): class PoolFormerFeatureExtractor(metaclass=DummyObject):
_backends = ["vision"] _backends = ["vision"]

View File

@@ -569,3 +569,50 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
# check that both inputs are handled correctly and generate the same output # check that both inputs are handled correctly and generate the same output
self.assertListEqual(output_expanded[:, -20:].tolist(), output[:, -20:].tolist()) self.assertListEqual(output_expanded[:, -20:].tolist(), output[:, -20:].tolist())
@slow
@require_bitsandbytes
def test_pixtral(self):
model_id = "hf-internal-testing/pixtral-12b"
model = LlavaForConditionalGeneration.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)
IMG_URLS = [
Image.open(requests.get("https://picsum.photos/id/237/400/300", stream=True).raw),
Image.open(requests.get("https://picsum.photos/id/231/200/300", stream=True).raw),
Image.open(requests.get("https://picsum.photos/id/27/500/500", stream=True).raw),
Image.open(requests.get("https://picsum.photos/id/17/150/600", stream=True).raw),
]
PROMPT = "<s>[INST]Describe the images.\n[IMG][IMG][IMG][IMG][/INST]"
# image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(text=PROMPT, images=IMG_URLS, return_tensors="pt").to("cuda")
generate_ids = model.generate(**inputs, max_new_tokens=500)
ouptut = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
# fmt: off
EXPECTED_GENERATION = """
Describe the images.
Sure, let's break down each image description:
1. **Image 1:**
- **Description:** A black dog with a glossy coat is sitting on a wooden floor. The dog has a focused expression and is looking directly at the camera.
- **Details:** The wooden floor has a rustic appearance with visible wood grain patterns. The dog's eyes are a striking color, possibly brown or amber, which contrasts with its black fur.
2. **Image 2:**
- **Description:** A scenic view of a mountainous landscape with a winding road cutting through it. The road is surrounded by lush green vegetation and leads to a distant valley.
- **Details:** The mountains are rugged with steep slopes, and the sky is clear, indicating good weather. The winding road adds a sense of depth and perspective to the image.
3. **Image 3:**
- **Description:** A beach scene with waves crashing against the shore. There are several people in the water and on the beach, enjoying the waves and the sunset.
- **Details:** The waves are powerful, creating a dynamic and lively atmosphere. The sky is painted with hues of orange and pink from the setting sun, adding a warm glow to the scene.
4. **Image 4:**
- **Description:** A garden path leading to a large tree with a bench underneath it. The path is bordered by well-maintained grass and flowers.
- **Details:** The path is made of small stones or gravel, and the tree provides a shaded area with the bench invitingly placed beneath it. The surrounding area is lush and green, suggesting a well-kept garden.
Each image captures a different scene, from a close-up of a dog to expansive natural landscapes, showcasing various elements of nature and human interaction with it.
"""
# fmt: on
# check that both inputs are handled correctly and generate the same output
self.assertListEqual(ouptut, EXPECTED_GENERATION)

View File

View File

@@ -0,0 +1,217 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import unittest
import numpy as np
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_vision_available
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
if is_torch_available():
import torch
if is_vision_available():
from PIL import Image
from transformers import PixtralImageProcessor
class PixtralImageProcessingTester(unittest.TestCase):
def __init__(
self,
parent,
batch_size=7,
num_channels=3,
image_size=18,
max_num_images_per_sample=3,
min_resolution=30,
max_resolution=400,
do_resize=True,
size=None,
patch_size=None,
do_normalize=True,
image_mean=[0.48145466, 0.4578275, 0.40821073],
image_std=[0.26862954, 0.26130258, 0.27577711],
do_convert_rgb=True,
):
size = size if size is not None else {"longest_edge": 24}
patch_size = patch_size if patch_size is not None else {"height": 8, "width": 8}
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
self.image_size = image_size
self.max_num_images_per_sample = max_num_images_per_sample
self.min_resolution = min_resolution
self.max_resolution = max_resolution
self.do_resize = do_resize
self.size = size
self.patch_size = patch_size
self.do_normalize = do_normalize
self.image_mean = image_mean
self.image_std = image_std
self.do_convert_rgb = do_convert_rgb
def prepare_image_processor_dict(self):
return {
"do_resize": self.do_resize,
"size": self.size,
"patch_size": self.patch_size,
"do_normalize": self.do_normalize,
"image_mean": self.image_mean,
"image_std": self.image_std,
"do_convert_rgb": self.do_convert_rgb,
}
def expected_output_image_shape(self, image):
if isinstance(image, Image.Image):
width, height = image.size
elif isinstance(image, np.ndarray):
height, width = image.shape[:2]
elif isinstance(image, torch.Tensor):
height, width = image.shape[-2:]
max_height = max_width = self.size.get("longest_edge")
ratio = max(height / max_height, width / max_width)
if ratio > 1:
height = int(np.ceil(height / ratio))
width = int(np.ceil(width / ratio))
patch_height, patch_width = self.patch_size["height"], self.patch_size["width"]
num_height_tokens = (height - 1) // patch_height + 1
num_width_tokens = (width - 1) // patch_width + 1
height = num_height_tokens * patch_height
width = num_width_tokens * patch_width
return self.num_channels, height, width
def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
# Use prepare_image_inputs to make a list of list of single images
images_list = []
for _ in range(self.batch_size):
images = []
for _ in range(random.randint(1, self.max_num_images_per_sample)):
img = prepare_image_inputs(
batch_size=1,
num_channels=self.num_channels,
min_resolution=self.min_resolution,
max_resolution=self.max_resolution,
equal_resolution=equal_resolution,
numpify=numpify,
torchify=torchify,
)[0]
images.append(img)
images_list.append(images)
return images_list
@require_torch
@require_vision
class PixtralImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = PixtralImageProcessor if is_vision_available() else None
def setUp(self):
super().setUp()
self.image_processor_tester = PixtralImageProcessingTester(self)
@property
def image_processor_dict(self):
return self.image_processor_tester.prepare_image_processor_dict()
def test_image_processor_properties(self):
image_processing = self.image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "size"))
self.assertTrue(hasattr(image_processing, "patch_size"))
self.assertTrue(hasattr(image_processing, "do_rescale"))
self.assertTrue(hasattr(image_processing, "rescale_factor"))
self.assertTrue(hasattr(image_processing, "do_normalize"))
self.assertTrue(hasattr(image_processing, "image_mean"))
self.assertTrue(hasattr(image_processing, "image_std"))
self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
def test_call_pil(self):
# Initialize image_processing
image_processing = self.image_processing_class(**self.image_processor_dict)
# create random PIL images
image_inputs_list = self.image_processor_tester.prepare_image_inputs()
for image_inputs in image_inputs_list:
for image in image_inputs:
self.assertIsInstance(image, Image.Image)
# Test not batched input
encoded_images = image_processing(image_inputs_list[0][0], return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs_list[0][0])
self.assertEqual(tuple(encoded_images[0][0].shape), expected_output_image_shape)
# Test batched
batch_encoded_images = image_processing(image_inputs_list, return_tensors="pt").pixel_values
for encoded_images, images in zip(batch_encoded_images, image_inputs_list):
for encoded_image, image in zip(encoded_images, images):
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image)
self.assertEqual(tuple(encoded_image.shape), expected_output_image_shape)
def test_call_numpy(self):
# Initialize image_processing
image_processing = self.image_processing_class(**self.image_processor_dict)
# create random numpy tensors
image_inputs_list = self.image_processor_tester.prepare_image_inputs(numpify=True)
for image_inputs in image_inputs_list:
for image in image_inputs:
self.assertIsInstance(image, np.ndarray)
# Test not batched input
encoded_images = image_processing(image_inputs_list[0][0], return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs_list[0][0])
self.assertEqual(tuple(encoded_images[0][0].shape), expected_output_image_shape)
# Test batched
batch_encoded_images = image_processing(image_inputs_list, return_tensors="pt").pixel_values
for encoded_images, images in zip(batch_encoded_images, image_inputs_list):
for encoded_image, image in zip(encoded_images, images):
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image)
self.assertEqual(tuple(encoded_image.shape), expected_output_image_shape)
def test_call_pytorch(self):
# Initialize image_processing
image_processing = self.image_processing_class(**self.image_processor_dict)
# create random PyTorch tensors
image_inputs_list = self.image_processor_tester.prepare_image_inputs(torchify=True)
for image_inputs in image_inputs_list:
for image in image_inputs:
self.assertIsInstance(image, torch.Tensor)
# Test not batched input
encoded_images = image_processing(image_inputs_list[0][0], return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs_list[0][0])
self.assertEqual(tuple(encoded_images[0][0].shape), expected_output_image_shape)
# Test batched
batch_encoded_images = image_processing(image_inputs_list, return_tensors="pt").pixel_values
for encoded_images, images in zip(batch_encoded_images, image_inputs_list):
for encoded_image, image in zip(encoded_images, images):
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image)
self.assertEqual(tuple(encoded_image.shape), expected_output_image_shape)
@unittest.skip(reason="PixtralImageProcessor doesn't treat 4 channel PIL and numpy consistently yet") # FIXME Amy
def test_call_numpy_4_channels(self):
pass

View File

@@ -0,0 +1,292 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch Pixtral model."""
import gc
import unittest
import requests
from transformers import (
AutoProcessor,
PixtralModel,
PixtralVisionConfig,
is_torch_available,
is_vision_available,
)
from transformers.testing_utils import (
require_bitsandbytes,
require_torch,
slow,
torch_device,
)
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor
if is_torch_available():
import torch
else:
is_torch_greater_or_equal_than_2_0 = False
if is_vision_available():
from PIL import Image
class PixtralModelTester:
def __init__(
self,
parent,
batch_size=12,
image_size=30,
patch_size=2,
num_channels=3,
is_training=True,
hidden_size=32,
projection_dim=32,
num_hidden_layers=2,
num_attention_heads=4,
intermediate_size=37,
dropout=0.1,
attention_dropout=0.1,
initializer_range=0.02,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.is_training = is_training
self.hidden_size = hidden_size
self.projection_dim = projection_dim
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.dropout = dropout
self.attention_dropout = attention_dropout
self.initializer_range = initializer_range
self.scope = scope
# in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
num_patches = (image_size // patch_size) ** 2
self.seq_length = num_patches + 1
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
config = self.get_config()
return config, pixel_values
def get_config(self):
return PixtralVisionConfig(
image_size=self.image_size,
patch_size=self.patch_size,
num_channels=self.num_channels,
hidden_size=self.hidden_size,
projection_dim=self.projection_dim,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
dropout=self.dropout,
attention_dropout=self.attention_dropout,
initializer_range=self.initializer_range,
)
def create_and_check_model(self, config, pixel_values):
model = PixtralModel(config=config)
model.to(torch_device)
model.eval()
with torch.no_grad():
result = model(pixel_values)
# expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
image_size = (self.image_size, self.image_size)
patch_size = (self.patch_size, self.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
def create_and_check_model_with_projection(self, config, pixel_values):
model = PixtralModel(config=config)
model.to(torch_device)
model.eval()
with torch.no_grad():
result = model(pixel_values)
# expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
image_size = (self.image_size, self.image_size)
patch_size = (self.patch_size, self.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
self.parent.assertEqual(result.image_embeds.shape, (self.batch_size, self.projection_dim))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values = config_and_inputs
inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict
@require_torch
class PixtralModelModelTest(ModelTesterMixin, unittest.TestCase):
"""
Model tester for `PixtralModel`.
"""
all_model_classes = (PixtralModel,) if is_torch_available() else ()
test_pruning = False
test_head_masking = False
def setUp(self):
self.model_tester = PixtralModelTester(self)
self.config_tester = ConfigTester(self, config_class=PixtralVisionConfig, has_text_modality=False)
@unittest.skip("model does not support input embeds")
def test_inputs_embeds(self):
pass
@unittest.skip("model does not support input embeds")
def test_inputs_embeds_matches_input_ids(self):
pass
@unittest.skip(
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
def test_training_gradient_checkpointing(self):
pass
@unittest.skip(
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
def test_training_gradient_checkpointing_use_reentrant(self):
pass
@unittest.skip(
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
@unittest.skip(reason="Compile not yet supported because in Pixtral models")
def test_sdpa_can_compile_dynamic(self):
pass
@unittest.skip(reason="Compile not yet supported because in Pixtral models")
def test_sdpa_can_dispatch_on_flash(self):
pass
@unittest.skip(reason="Not supported yet")
def test_attention_outputs(self):
pass
@unittest.skip(reason="Not supported yet")
def test_cpu_offload(self):
pass
@unittest.skip(reason="Not supported yet")
def test_batching_equivalence(self):
pass
@unittest.skip(reason="Not supported yet")
def test_disk_offload_bin(self):
pass
@unittest.skip(reason="Not supported yet")
def test_retain_grad_hidden_states_attentions(self):
pass
@unittest.skip(reason="Not supported yet")
def test_multi_gpu_data_parallel_forward(self):
pass
@unittest.skip(reason="Not supported yet")
def test_model_parallelism(self):
pass
@unittest.skip(reason="Not supported yet")
def test_model_outputs_equivalence(self):
pass
@unittest.skip(reason="Not supported yet")
def test_save_load(self):
pass
@unittest.skip(reason="Not supported yet")
def test_model_get_set_embeddings(self):
pass
@unittest.skip(reason="Not supported yet")
def test_resize_tokens_embeddings(self):
pass
@unittest.skip(reason="Not supported yet")
def test_model_main_input_name(self):
pass
@unittest.skip(reason="Not supported yet")
def test_initialization(self):
pass
@unittest.skip(reason="Not supported yet")
def test_hidden_states_output(self):
pass
@unittest.skip(reason="Not supported yet")
def test_gradient_checkpointing_backward_compatibility(self):
pass
@unittest.skip(reason="Not supported yet")
def test_feed_forward_chunking(self):
pass
@unittest.skip(reason="Not supported yet")
def test_disk_offload_safetensors(self):
pass
@unittest.skip(reason="Not supported yet")
def test_determinism(self):
pass
@require_torch
class PixtralModelIntegrationTest(unittest.TestCase):
def setUp(self):
self.processor = AutoProcessor.from_pretrained("hf-internal-testing/pixtral-12b")
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
@slow
@require_bitsandbytes
def test_small_model_integration_test(self):
# Let' s make sure we test the preprocessing to replace what is used
model = PixtralModel.from_pretrained("hf-internal-testing/pixtral-12b", load_in_4bit=True)
prompt = "<s>[INST][IMG]\nWhat are the things I should be cautious about when I visit this place?[/INST]"
image_file = "https://pixtral-vl.github.io/static/images/view.jpg"
raw_image = Image.open(requests.get(image_file, stream=True).raw)
inputs = self.processor(prompt, raw_image, return_tensors="pt")
EXPECTED_INPUT_IDS = torch.tensor([[1, 32000, 28705, 13, 11123, 28747, 1824, 460, 272, 1722,315, 1023, 347, 13831, 925, 684, 739, 315, 3251, 456,1633, 28804, 13, 4816, 8048, 12738, 28747]]) # fmt: skip
self.assertTrue(torch.equal(inputs["input_ids"], EXPECTED_INPUT_IDS))
output = model.generate(**inputs, max_new_tokens=20)
EXPECTED_DECODED_TEXT = "\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT: When visiting this place, there are a few things one should be cautious about. Firstly," # fmt: skip
self.assertEqual(
self.processor.decode(output[0], skip_special_tokens=True),
EXPECTED_DECODED_TEXT,
)

View File

@@ -0,0 +1,233 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import requests
import torch
from transformers.testing_utils import require_vision
from transformers.utils import is_vision_available
if is_vision_available():
from PIL import Image
from transformers import AutoTokenizer, PixtralImageProcessor, PixtralProcessor
@require_vision
class PixtralProcessorTest(unittest.TestCase):
processor_class = PixtralProcessor
@classmethod
def setUpClass(cls):
cls.url_0 = "https://www.ilankelman.org/stopsigns/australia.jpg"
cls.image_0 = Image.open(requests.get(cls.url_0, stream=True).raw)
cls.url_1 = "http://images.cocodataset.org/val2017/000000039769.jpg"
cls.image_1 = Image.open(requests.get(cls.url_1, stream=True).raw)
cls.url_2 = "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg"
cls.image_2 = Image.open(requests.get(cls.url_2, stream=True).raw)
def setUp(self):
super().setUp()
# FIXME - just load the processor directly from the checkpoint
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/pixtral-12b")
image_processor = PixtralImageProcessor()
self.processor = PixtralProcessor(tokenizer=tokenizer, image_processor=image_processor)
@unittest.skip("No chat template was set for this model (yet)")
def test_chat_template(self):
expected_prompt = "USER: [IMG]\nWhat is shown in this image? ASSISTANT:"
messages = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": "What is shown in this image?"},
],
},
]
formatted_prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True)
self.assertEqual(expected_prompt, formatted_prompt)
@unittest.skip("No chat template was set for this model (yet)")
def test_image_token_filling(self):
# Important to check with non square image
image = torch.randint(0, 2, (3, 500, 316))
expected_image_tokens = 1526
image_token_index = 32000
messages = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": "What is shown in this image?"},
],
},
]
inputs = self.processor(
text=[self.processor.apply_chat_template(messages)],
images=[image],
return_tensors="pt",
)
image_tokens = (inputs["input_ids"] == image_token_index).sum().item()
self.assertEqual(expected_image_tokens, image_tokens)
def test_processor_with_single_image(self):
prompt_string = "USER: [IMG]\nWhat's the content of the image? ASSISTANT:"
# Make small for checking image token expansion
self.processor.image_processor.size = {"longest_edge": 30}
self.processor.image_processor.patch_size = {"height": 2, "width": 2}
# Test passing in an image
inputs_image = self.processor(text=prompt_string, images=self.image_0, return_tensors="pt")
self.assertIn("input_ids", inputs_image)
self.assertTrue(len(inputs_image["input_ids"]) == 1)
self.assertIsInstance(inputs_image["input_ids"], torch.Tensor)
self.assertIsInstance(inputs_image["pixel_values"], list)
self.assertTrue(len(inputs_image["pixel_values"]) == 1)
self.assertIsInstance(inputs_image["pixel_values"][0], list)
self.assertTrue(len(inputs_image["pixel_values"][0]) == 1)
self.assertIsInstance(inputs_image["pixel_values"][0][0], torch.Tensor)
# fmt: off
input_ids = inputs_image["input_ids"]
self.assertEqual(
input_ids[0].tolist(),
# Equivalent to "USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the content of the image? ASSISTANT:"
[21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 4701, 1307, 1278, 3937, 1063, 1349, 4290, 16002, 41150, 1058]
)
# fmt: on
# Test passing in a url
inputs_url = self.processor(text=prompt_string, images=self.url_0, return_tensors="pt")
self.assertIn("input_ids", inputs_url)
self.assertTrue(len(inputs_url["input_ids"]) == 1)
self.assertIsInstance(inputs_url["input_ids"], torch.Tensor)
self.assertIsInstance(inputs_url["pixel_values"], list)
self.assertTrue(len(inputs_url["pixel_values"]) == 1)
self.assertIsInstance(inputs_url["pixel_values"][0], list)
self.assertTrue(len(inputs_url["pixel_values"][0]) == 1)
self.assertIsInstance(inputs_url["pixel_values"][0][0], torch.Tensor)
# fmt: off
input_ids = inputs_url["input_ids"]
self.assertEqual(
input_ids[0].tolist(),
# Equivalent to "USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the content of the image? ASSISTANT:"
[21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 4701, 1307, 1278, 3937, 1063, 1349, 4290, 16002, 41150, 1058]
)
# fmt: on
def test_processor_with_multiple_images_single_list(self):
prompt_string = "USER: [IMG][IMG]\nWhat's the difference between these two images? ASSISTANT:"
# Make small for checking image token expansion
self.processor.image_processor.size = {"longest_edge": 30}
self.processor.image_processor.patch_size = {"height": 2, "width": 2}
# Test passing in an image
inputs_image = self.processor(text=prompt_string, images=[self.image_0, self.image_1], return_tensors="pt")
self.assertIn("input_ids", inputs_image)
self.assertTrue(len(inputs_image["input_ids"]) == 1)
self.assertIsInstance(inputs_image["input_ids"], torch.Tensor)
self.assertIsInstance(inputs_image["pixel_values"], list)
self.assertTrue(len(inputs_image["pixel_values"]) == 1)
self.assertIsInstance(inputs_image["pixel_values"][0], list)
self.assertTrue(len(inputs_image["pixel_values"][0]) == 2)
self.assertIsInstance(inputs_image["pixel_values"][0][0], torch.Tensor)
# fmt: off
input_ids = inputs_image["input_ids"]
self.assertEqual(
input_ids[0].tolist(),
# Equivalent to ["USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END][IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the difference between these two images? ASSISTANT:"]
[21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058]
)
# fmt: on
# Test passing in a url
inputs_url = self.processor(text=prompt_string, images=[self.url_0, self.url_1], return_tensors="pt")
self.assertIn("input_ids", inputs_url)
self.assertTrue(len(inputs_url["input_ids"]) == 1)
self.assertIsInstance(inputs_url["input_ids"], torch.Tensor)
self.assertIsInstance(inputs_url["pixel_values"], list)
self.assertTrue(len(inputs_url["pixel_values"]) == 1)
self.assertIsInstance(inputs_url["pixel_values"][0], list)
self.assertTrue(len(inputs_url["pixel_values"][0]) == 2)
self.assertIsInstance(inputs_url["pixel_values"][0][0], torch.Tensor)
# fmt: off
input_ids = inputs_url["input_ids"]
self.assertEqual(
input_ids[0].tolist(),
# Equivalent to ["USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END][IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the difference between these two images? ASSISTANT:"]
[21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058]
)
# fmt: on
def test_processor_with_multiple_images_multiple_lists(self):
prompt_string = [
"USER: [IMG][IMG]\nWhat's the difference between these two images? ASSISTANT:",
"USER: [IMG]\nWhat's the content of the image? ASSISTANT:",
]
self.processor.tokenizer.pad_token = "</s>"
image_inputs = [[self.image_0, self.image_1], [self.image_2]]
# Make small for checking image token expansion
self.processor.image_processor.size = {"longest_edge": 30}
self.processor.image_processor.patch_size = {"height": 2, "width": 2}
# Test passing in an image
inputs_image = self.processor(text=prompt_string, images=image_inputs, return_tensors="pt", padding=True)
self.assertIn("input_ids", inputs_image)
self.assertTrue(len(inputs_image["input_ids"]) == 2)
self.assertIsInstance(inputs_image["input_ids"], torch.Tensor)
self.assertIsInstance(inputs_image["pixel_values"], list)
self.assertTrue(len(inputs_image["pixel_values"]) == 2)
self.assertIsInstance(inputs_image["pixel_values"][0], list)
self.assertTrue(len(inputs_image["pixel_values"][0]) == 2)
self.assertIsInstance(inputs_image["pixel_values"][0][0], torch.Tensor)
# fmt: off
input_ids = inputs_image["input_ids"]
self.assertEqual(
input_ids[0].tolist(),
# Equivalent to ["USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END][IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the difference between these two images? ASSISTANT:"]
[21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058]
)
# fmt: on
# Test passing in a url
inputs_url = self.processor(text=prompt_string, images=image_inputs, return_tensors="pt", padding=True)
self.assertIn("input_ids", inputs_url)
self.assertTrue(len(inputs_url["input_ids"]) == 2)
self.assertIsInstance(inputs_url["input_ids"], torch.Tensor)
self.assertIsInstance(inputs_url["pixel_values"], list)
self.assertTrue(len(inputs_url["pixel_values"]) == 2)
self.assertIsInstance(inputs_url["pixel_values"][0], list)
self.assertTrue(len(inputs_url["pixel_values"][0]) == 2)
self.assertIsInstance(inputs_url["pixel_values"][0][0], torch.Tensor)
# fmt: off
input_ids = inputs_url["input_ids"]
self.assertEqual(
input_ids[0].tolist(),
# Equivalent to ["USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END][IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the difference between these two images? ASSISTANT:"]
[21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058]
)
# fmt: on