Add Mistral3 (#36790)
Some checks failed
Release - Conda / build_and_package (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled

* initial start

* style and dummies

* Create convert_mistral3_weights_to_hf.py

* update

* typo

* typo

* Update convert_mistral3_weights_to_hf.py

* Update convert_mistral3_weights_to_hf.py

* Update convert_mistral3_weights_to_hf.py

* Update convert_mistral3_weights_to_hf.py

* up

* Update convert_mistral3_weights_to_hf.py

* Update convert_mistral3_weights_to_hf.py

* update

* update

* Update image_processing_mistral3.py

* Update convert_mistral3_weights_to_hf.py

* fix patch merger

* Update convert_mistral3_weights_to_hf.py

* Update convert_mistral3_weights_to_hf.py

* up

* update modular to fit

* style

* Update convert_mistral3_weights_to_hf.py

* typo

* Update modular_mistral3.py

* simplify a lot all shape shenanigans

* simplify

* add working test processor

* Add partially working common modeling tests

* All tests working and remove mistral3 image processors

* add docs and fixup

* fix inference with image size >1540

* 🚨fix test image proc pixtral

* Remove vision_feature_select_strategy

* Update convert_mistral3_weights_to_hf.py

* Update convert_mistral3_weights_to_hf.py

* Update convert_mistral3_weights_to_hf.py

* Update convert_mistral3_weights_to_hf.py

* clean

* fix test checkpoints

* Update test_modeling_mistral3.py

* Update test_modeling_mistral3.py

* style

* Use Pixtral processor

* up

* finish cleaning processor to use pixtral directly

* Update __init__.py

* Update processing_pixtral.py

* doc

* Update __init__.py

* Update mistral3.md

* Update _toctree.yml

---------

Co-authored-by: yonigozlan <yoni.gozlan@huggingface.co>
Co-authored-by: yonigozlan <yoni.gozlan10@gmail.com>
This commit is contained in:
Cyril Vallez
2025-03-18 12:04:42 +01:00
committed by GitHub
parent bd92073692
commit e959530b8f
21 changed files with 2303 additions and 6 deletions

View File

@@ -529,6 +529,8 @@
title: MegatronGPT2 title: MegatronGPT2
- local: model_doc/mistral - local: model_doc/mistral
title: Mistral title: Mistral
- local: model_doc/mistral3
title: Mistral3
- local: model_doc/mixtral - local: model_doc/mixtral
title: Mixtral title: Mixtral
- local: model_doc/mluke - local: model_doc/mluke

View File

@@ -0,0 +1,234 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# Mistral3
## Overview
Building upon Mistral Small 3 (2501), Mistral Small 3.1 (2503) adds state-of-the-art vision understanding and enhances long context capabilities up to 128k tokens without compromising text performance. With 24 billion parameters, this model achieves top-tier capabilities in both text and vision tasks.
It is ideal for:
- Fast-response conversational agents.
- Low-latency function calling.
- Subject matter experts via fine-tuning.
- Local inference for hobbyists and organizations handling sensitive data.
- Programming and math reasoning.
- Long document understanding.
- Visual understanding.
This model was contributed by [cyrilvallez](https://huggingface.co/cyrilvallez) and [yonigozlan](https://huggingface.co/yonigozlan).
The original code can be found [here](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/pixtral.py) and [here](https://github.com/mistralai/mistral-common).
## Usage example
### Inference with Pipeline
Here is how you can use the `image-text-to-text` pipeline to perform inference with the `Mistral3` models in just a few lines of code:
```python
>>> from transformers import pipeline
>>> messages = [
... {
... "role": "user",
... "content": [
... {
... "type": "image",
... "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg",
... },
... {"type": "text", "text": "Describe this image."},
... ],
... },
... ]
>>> pipe = pipeline("image-text-to-text", model="../mistral3_weights", torch_dtype=torch.bfloat16)
>>> outputs = pipe(text=messages, max_new_tokens=50, return_full_text=False)
>>> outputs[0]["generated_text"]
'The image depicts a vibrant and lush garden scene featuring a variety of wildflowers and plants. The central focus is on a large, pinkish-purple flower, likely a Greater Celandine (Chelidonium majus), with a'
```
### Inference on a single image
This example demonstrates how to perform inference on a single image with the Mistral3 models using chat templates.
```python
>>> from transformers import AutoProcessor, AutoModelForImageTextToText
>>> import torch
>>> torch_device = "cuda"
>>> model_checkpoint = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
>>> processor = AutoProcessor.from_pretrained(model_checkpoint)
>>> model = AutoModelForImageTextToText.from_pretrained(model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16)
>>> messages = [
... {
... "role": "user",
... "content": [
... {"type": "image", "url": "http://images.cocodataset.org/val2017/000000039769.jpg"},
... {"type": "text", "text": "Describe this image"},
... ],
... }
... ]
>>> inputs = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt").to(model.device, dtype=torch.bfloat16)
>>> generate_ids = model.generate(**inputs, max_new_tokens=20)
>>> decoded_output = processor.decode(generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True)
>>> decoded_output
"The image depicts two cats lying on a pink blanket. The larger cat, which appears to be an"...
```
### Text-only generation
This example shows how to generate text using the Mistral3 model without providing any image input.
````python
>>> from transformers import AutoProcessor, AutoModelForImageTextToText
>>> import torch
>>> torch_device = "cuda"
>>> model_checkpoint = ".mistralai/Mistral-Small-3.1-24B-Instruct-2503"
>>> processor = AutoProcessor.from_pretrained(model_checkpoint)
>>> model = AutoModelForImageTextToText.from_pretrained(model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16)
>>> SYSTEM_PROMPT = "You are a conversational agent that always answers straight to the point, always end your accurate response with an ASCII drawing of a cat."
>>> user_prompt = "Give me 5 non-formal ways to say 'See you later' in French."
>>> messages = [
... {"role": "system", "content": SYSTEM_PROMPT},
... {"role": "user", "content": user_prompt},
... ]
>>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
>>> inputs = processor(text=text, return_tensors="pt").to(0, dtype=torch.float16)
>>> generate_ids = model.generate(**inputs, max_new_tokens=50, do_sample=False)
>>> decoded_output = processor.batch_decode(generate_ids[:, inputs["input_ids"].shape[1] :], skip_special_tokens=True)[0]
>>> print(decoded_output)
"1. À plus tard!
2. Salut, à plus!
3. À toute!
4. À la prochaine!
5. Je me casse, à plus!
```
/\_/\
( o.o )
> ^ <
```"
````
### Batched image and text inputs
Mistral3 models also support batched image and text inputs.
```python
>>> from transformers import AutoProcessor, AutoModelForImageTextToText
>>> import torch
>>> torch_device = "cuda"
>>> model_checkpoint = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
>>> processor = AutoProcessor.from_pretrained(model_checkpoint)
>>> model = AutoModelForImageTextToText.from_pretrained(model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16)
>>> messages = [
... [
... {
... "role": "user",
... "content": [
... {"type": "image", "url": "https://llava-vl.github.io/static/images/view.jpg"},
... {"type": "text", "text": "Write a haiku for this image"},
... ],
... },
... ],
... [
... {
... "role": "user",
... "content": [
... {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
... {"type": "text", "text": "Describe this image"},
... ],
... },
... ],
... ]
>>> inputs = processor.apply_chat_template(messages, padding=True, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt").to(model.device, dtype=torch.bfloat16)
>>> output = model.generate(**inputs, max_new_tokens=25)
>>> decoded_outputs = processor.batch_decode(output, skip_special_tokens=True)
>>> decoded_outputs
["Write a haiku for this imageCalm waters reflect\nWhispers of the forest's breath\nPeace on wooden path"
, "Describe this imageThe image depicts a vibrant street scene in what appears to be a Chinatown district. The focal point is a traditional Chinese"]
```
### Batched multi-image input and quantization with BitsAndBytes
This implementation of the Mistral3 models supports batched text-images inputs with different number of images for each text.
This example also how to use `BitsAndBytes` to load the model in 4bit quantization.
```python
>>> from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig
>>> import torch
>>> torch_device = "cuda"
>>> model_checkpoint = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
>>> processor = AutoProcessor.from_pretrained(model_checkpoint)
>>> quantization_config = BitsAndBytesConfig(load_in_4bit=True)
>>> model = AutoModelForImageTextToText.from_pretrained(
... model_checkpoint, quantization_config=quantization_config
... )
>>> messages = [
...     [
...         {
...             "role": "user",
...             "content": [
...                 {"type": "image", "url": "https://llava-vl.github.io/static/images/view.jpg"},
...                 {"type": "text", "text": "Write a haiku for this image"},
...             ],
...         },
...     ],
...     [
...         {
...             "role": "user",
...             "content": [
...                 {"type": "image", "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"},
...                 {"type": "image", "url": "https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg"},
...                 {"type": "text", "text": "These images depict two different landmarks. Can you identify them?"},
...             ],
...         },
...     ],
>>> ]
>>> inputs = processor.apply_chat_template(messages, padding=True, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt").to(model.device, dtype=torch.bfloat16)
>>> output = model.generate(**inputs, max_new_tokens=25)
>>> decoded_outputs = processor.batch_decode(output, skip_special_tokens=True)
>>> decoded_outputs
["Write a haiku for this imageSure, here is a haiku inspired by the image:\n\nCalm lake's wooden path\nSilent forest stands guard\n", "These images depict two different landmarks. Can you identify them? Certainly! The images depict two iconic landmarks:\n\n1. The first image shows the Statue of Liberty in New York City."]
```
## Mistral3Config
[[autodoc]] Mistral3Config
## Mistral3ForConditionalGeneration
[[autodoc]] Mistral3ForConditionalGeneration
- forward

View File

@@ -613,6 +613,7 @@ _import_structure = {
], ],
"models.mimi": ["MimiConfig"], "models.mimi": ["MimiConfig"],
"models.mistral": ["MistralConfig"], "models.mistral": ["MistralConfig"],
"models.mistral3": ["Mistral3Config"],
"models.mixtral": ["MixtralConfig"], "models.mixtral": ["MixtralConfig"],
"models.mllama": [ "models.mllama": [
"MllamaConfig", "MllamaConfig",
@@ -2940,6 +2941,12 @@ else:
"MistralPreTrainedModel", "MistralPreTrainedModel",
] ]
) )
_import_structure["models.mistral3"].extend(
[
"Mistral3ForConditionalGeneration",
"Mistral3PreTrainedModel",
]
)
_import_structure["models.mixtral"].extend( _import_structure["models.mixtral"].extend(
[ [
"MixtralForCausalLM", "MixtralForCausalLM",
@@ -5788,6 +5795,7 @@ if TYPE_CHECKING:
MimiConfig, MimiConfig,
) )
from .models.mistral import MistralConfig from .models.mistral import MistralConfig
from .models.mistral3 import Mistral3Config
from .models.mixtral import MixtralConfig from .models.mixtral import MixtralConfig
from .models.mllama import ( from .models.mllama import (
MllamaConfig, MllamaConfig,
@@ -7844,6 +7852,10 @@ if TYPE_CHECKING:
MistralModel, MistralModel,
MistralPreTrainedModel, MistralPreTrainedModel,
) )
from .models.mistral3 import (
Mistral3ForConditionalGeneration,
Mistral3PreTrainedModel,
)
from .models.mixtral import ( from .models.mixtral import (
MixtralForCausalLM, MixtralForCausalLM,
MixtralForQuestionAnswering, MixtralForQuestionAnswering,

View File

@@ -169,6 +169,7 @@ from . import (
mgp_str, mgp_str,
mimi, mimi,
mistral, mistral,
mistral3,
mixtral, mixtral,
mllama, mllama,
mluke, mluke,

View File

@@ -192,6 +192,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("mgp-str", "MgpstrConfig"), ("mgp-str", "MgpstrConfig"),
("mimi", "MimiConfig"), ("mimi", "MimiConfig"),
("mistral", "MistralConfig"), ("mistral", "MistralConfig"),
("mistral3", "Mistral3Config"),
("mixtral", "MixtralConfig"), ("mixtral", "MixtralConfig"),
("mllama", "MllamaConfig"), ("mllama", "MllamaConfig"),
("mobilebert", "MobileBertConfig"), ("mobilebert", "MobileBertConfig"),
@@ -537,6 +538,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("mgp-str", "MGP-STR"), ("mgp-str", "MGP-STR"),
("mimi", "Mimi"), ("mimi", "Mimi"),
("mistral", "Mistral"), ("mistral", "Mistral"),
("mistral3", "Mistral3"),
("mixtral", "Mixtral"), ("mixtral", "Mixtral"),
("mllama", "Mllama"), ("mllama", "Mllama"),
("mluke", "mLUKE"), ("mluke", "mLUKE"),

View File

@@ -111,6 +111,7 @@ else:
("mask2former", ("Mask2FormerImageProcessor",)), ("mask2former", ("Mask2FormerImageProcessor",)),
("maskformer", ("MaskFormerImageProcessor",)), ("maskformer", ("MaskFormerImageProcessor",)),
("mgp-str", ("ViTImageProcessor", "ViTImageProcessorFast")), ("mgp-str", ("ViTImageProcessor", "ViTImageProcessorFast")),
("mistral3", ("PixtralImageProcessor", "PixtralImageProcessorFast")),
("mllama", ("MllamaImageProcessor",)), ("mllama", ("MllamaImageProcessor",)),
("mobilenet_v1", ("MobileNetV1ImageProcessor",)), ("mobilenet_v1", ("MobileNetV1ImageProcessor",)),
("mobilenet_v2", ("MobileNetV2ImageProcessor",)), ("mobilenet_v2", ("MobileNetV2ImageProcessor",)),

View File

@@ -361,6 +361,7 @@ MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
("mamba2", "Mamba2ForCausalLM"), ("mamba2", "Mamba2ForCausalLM"),
("mega", "MegaForMaskedLM"), ("mega", "MegaForMaskedLM"),
("megatron-bert", "MegatronBertForPreTraining"), ("megatron-bert", "MegatronBertForPreTraining"),
("mistral3", "Mistral3ForConditionalGeneration"),
("mllama", "MllamaForConditionalGeneration"), ("mllama", "MllamaForConditionalGeneration"),
("mobilebert", "MobileBertForPreTraining"), ("mobilebert", "MobileBertForPreTraining"),
("mpnet", "MPNetForMaskedLM"), ("mpnet", "MPNetForMaskedLM"),
@@ -802,6 +803,7 @@ MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
("llava_next", "LlavaNextForConditionalGeneration"), ("llava_next", "LlavaNextForConditionalGeneration"),
("llava_next_video", "LlavaNextVideoForConditionalGeneration"), ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),
("llava_onevision", "LlavaOnevisionForConditionalGeneration"), ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
("mistral3", "Mistral3ForConditionalGeneration"),
("mllama", "MllamaForConditionalGeneration"), ("mllama", "MllamaForConditionalGeneration"),
("paligemma", "PaliGemmaForConditionalGeneration"), ("paligemma", "PaliGemmaForConditionalGeneration"),
("pix2struct", "Pix2StructForConditionalGeneration"), ("pix2struct", "Pix2StructForConditionalGeneration"),
@@ -839,6 +841,7 @@ MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict(
("llava", "LlavaForConditionalGeneration"), ("llava", "LlavaForConditionalGeneration"),
("llava_next", "LlavaNextForConditionalGeneration"), ("llava_next", "LlavaNextForConditionalGeneration"),
("llava_onevision", "LlavaOnevisionForConditionalGeneration"), ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
("mistral3", "Mistral3ForConditionalGeneration"),
("mllama", "MllamaForConditionalGeneration"), ("mllama", "MllamaForConditionalGeneration"),
("paligemma", "PaliGemmaForConditionalGeneration"), ("paligemma", "PaliGemmaForConditionalGeneration"),
("pix2struct", "Pix2StructForConditionalGeneration"), ("pix2struct", "Pix2StructForConditionalGeneration"),

View File

@@ -84,6 +84,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
("markuplm", "MarkupLMProcessor"), ("markuplm", "MarkupLMProcessor"),
("mctct", "MCTCTProcessor"), ("mctct", "MCTCTProcessor"),
("mgp-str", "MgpstrProcessor"), ("mgp-str", "MgpstrProcessor"),
("mistral3", "PixtralProcessor"),
("mllama", "MllamaProcessor"), ("mllama", "MllamaProcessor"),
("moonshine", "Wav2Vec2Processor"), ("moonshine", "Wav2Vec2Processor"),
("oneformer", "OneFormerProcessor"), ("oneformer", "OneFormerProcessor"),

View File

@@ -0,0 +1,28 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_mistral3 import *
from .modeling_mistral3 import *
from .processing_mistral3 import *
else:
import sys
_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)

View File

@@ -0,0 +1,137 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...configuration_utils import PretrainedConfig
from ..auto import CONFIG_MAPPING, AutoConfig
class Mistral3Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Mistral3ForConditionalGeneration`]. It is used to instantiate an
Mistral3 model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of
[mistralai/Mistral-Small-3.1-24B-Instruct-2503](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `PixtralVisionConfig`):
The config object or dictionary of the vision backbone.
text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `MistralConfig`):
The config object or dictionary of the text backbone.
image_token_index (`int`, *optional*, defaults to 10):
The image token index to encode the image prompt.
projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
The activation function used by the multimodal projector.
vision_feature_layer (`Union[int, List[int]]`, *optional*, defaults to -1):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
multimodal_projector_bias (`bool`, *optional*, defaults to `False`):
Whether to use bias in the multimodal projector.
spatial_merge_size (`int`, *optional*, defaults to 2):
The downsampling factor for the spatial merge operation.
Example:
```python
>>> from transformers import Mistral3ForConditionalGeneration, Mistral3Config, PixtralVisionConfig, MistralConfig
>>> # Initializing a Pixtral-vision config
>>> vision_config = PixtralVisionConfig()
>>> # Initializing a Mistral config
>>> text_config = MistralConfig()
>>> # Initializing a Mistral3 configuration
>>> configuration = Mistral3Config(vision_config, text_config)
>>> # Initializing a model from the mistral3.1 configuration
>>> model = Mistral3ForConditionalGeneration(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "mistral3"
sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig}
is_composition = True
def __init__(
self,
vision_config=None,
text_config=None,
image_token_index=10,
projector_hidden_act="gelu",
vision_feature_layer=-1,
multimodal_projector_bias=False,
spatial_merge_size=2,
**kwargs,
):
super().__init__(**kwargs)
self.image_token_index = image_token_index
self.projector_hidden_act = projector_hidden_act
self.vision_feature_layer = vision_feature_layer
if isinstance(vision_config, dict):
vision_config["model_type"] = vision_config["model_type"] if "model_type" in vision_config else "pixtral"
vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
elif vision_config is None:
vision_config = CONFIG_MAPPING["pixtral"](
intermediate_size=4096,
hidden_size=1024,
patch_size=14,
image_size=1540,
num_hidden_layers=24,
num_attention_heads=16,
vocab_size=32000,
head_dim=64,
hidden_act="gelu",
)
self.vision_config = vision_config
if isinstance(text_config, dict):
text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "mistral"
text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
elif text_config is None:
text_config = CONFIG_MAPPING["mistral"](
attention_dropout=0.0,
head_dim=128,
hidden_act="silu",
hidden_size=5120,
initializer_range=0.02,
intermediate_size=32768,
max_position_embeddings=131072,
model_type="mistral",
num_attention_heads=32,
num_hidden_layers=40,
num_key_value_heads=8,
rms_norm_eps=1e-05,
rope_theta=1000000000.0,
sliding_window=None,
use_cache=True,
vocab_size=131072,
)
self.text_config = text_config
self.multimodal_projector_bias = multimodal_projector_bias
self.spatial_merge_size = spatial_merge_size
__all__ = ["Mistral3Config"]

View File

@@ -0,0 +1,241 @@
# Copyright 2023 Mistral AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import os
import re
import torch
from safetensors.torch import load_file
from transformers import (
Mistral3Config,
Mistral3ForConditionalGeneration,
MistralConfig,
PixtralImageProcessorFast,
PixtralProcessor,
PixtralVisionConfig,
)
from transformers.integrations.mistral import convert_tekken_tokenizer
# fmt: off
STATE_DICT_MAPPING = {
# Text model keys
r"^output.weight": r"language_model.lm_head.weight",
r"^norm.weight": r"language_model.model.norm.weight",
r"^tok_embeddings.weight": r"language_model.model.embed_tokens.weight",
r"^layers.(\d+).attention_norm.weight": r"language_model.model.layers.\1.input_layernorm.weight",
r"^layers.(\d+).ffn_norm.weight": r"language_model.model.layers.\1.post_attention_layernorm.weight",
r"^layers.(\d+).attention.w(q|k|v|o).weight": r"language_model.model.layers.\1.self_attn.\2_proj.weight",
r"^layers.(\d+).feed_forward.w1.weight": r"language_model.model.layers.\1.mlp.gate_proj.weight",
r"^layers.(\d+).feed_forward.w2.weight": r"language_model.model.layers.\1.mlp.down_proj.weight",
r"^layers.(\d+).feed_forward.w3.weight": r"language_model.model.layers.\1.mlp.up_proj.weight",
# Vision model keys
r"vision_encoder.transformer.layers.(\d+).attention_norm.weight": r"vision_tower.transformer.layers.\1.attention_norm.weight",
r"^vision_encoder.transformer.layers.(\d+).ffn_norm.weight": r"vision_tower.transformer.layers.\1.ffn_norm.weight",
r"^vision_encoder.transformer.layers.(\d+).attention.w(q|k|v|o).weight": r"vision_tower.transformer.layers.\1.attention.\2_proj.weight",
r"^vision_encoder.transformer.layers.(\d+).feed_forward.w1.weight": r"vision_tower.transformer.layers.\1.feed_forward.gate_proj.weight",
r"^vision_encoder.transformer.layers.(\d+).feed_forward.w2.weight": r"vision_tower.transformer.layers.\1.feed_forward.down_proj.weight",
r"^vision_encoder.transformer.layers.(\d+).feed_forward.w3.weight": r"vision_tower.transformer.layers.\1.feed_forward.up_proj.weight",
r"^vision_language_adapter.w_in": r"multi_modal_projector.linear_1",
r"^vision_language_adapter.w_out": r"multi_modal_projector.linear_2",
r"^vision_encoder.ln_pre.weight": r"vision_tower.ln_pre.weight",
r"^vision_encoder.patch_conv.weight": r"vision_tower.patch_conv.weight",
r"^patch_merger.merging_layer.weight": r"multi_modal_projector.patch_merger.merging_layer.weight",
r"^pre_mm_projector_norm.weight": r"multi_modal_projector.norm.weight",
}
# fmt: on
def map_old_key_to_new(old_key):
"""Map of a key of the original state dict to the equivalent key in HF format"""
for pattern, replacement in STATE_DICT_MAPPING.items():
new_key, n_replace = re.subn(pattern, replacement, old_key)
# Early exit of the loop
if n_replace > 0:
return new_key
raise ValueError(f"Key: {old_key} could not be mapped (check the mapping).")
def read_json(path):
with open(path, "r") as f:
return json.load(f)
def permute_for_rope(tensor, n_heads, dim1, dim2):
"""Permute the weights for the ROPE formulation."""
tensor = tensor.view(n_heads, dim1 // n_heads // 2, 2, dim2)
tensor = tensor.transpose(1, 2)
tensor = tensor.reshape(dim1, dim2)
return tensor
def convert_state_dict(original_state_dict: dict, config: MistralConfig):
"""Convert a state dict file, when a single `nn.Module` is never sharded in different files (usual case)."""
new_dict = {}
for old_key, tensor in original_state_dict.items():
new_key = map_old_key_to_new(old_key)
if "vision" in old_key:
num_attention_heads = config.vision_config.num_attention_heads
num_key_value_heads = num_attention_heads
hidden_size = config.vision_config.hidden_size
head_dim = config.vision_config.head_dim
key_value_dim = head_dim * num_attention_heads
query_dim = head_dim * num_attention_heads
else:
num_attention_heads = config.text_config.num_attention_heads
hidden_size = config.text_config.hidden_size
head_dim = config.text_config.head_dim
num_key_value_heads = config.text_config.num_key_value_heads
key_value_dim = head_dim * num_key_value_heads
query_dim = head_dim * num_attention_heads
if "q_proj" in new_key:
tensor = permute_for_rope(tensor, num_attention_heads, query_dim, hidden_size)
elif "k_proj" in new_key:
tensor = permute_for_rope(tensor, num_key_value_heads, key_value_dim, hidden_size)
new_dict[new_key] = tensor
return new_dict
def convert_config(original_config: dict, max_position_embeddings: int = 131072):
original_vision_config = original_config.pop("vision_encoder")
original_text_config = original_config
# Text config
text_key_mapping = {
"hidden_size": "dim",
"num_hidden_layers": "n_layers",
"intermediate_size": "hidden_dim",
"num_attention_heads": "n_heads",
"num_key_value_heads": "n_kv_heads",
"rms_norm_eps": "norm_eps",
}
similar_text_keys_to_keep = [
"head_dim",
"vocab_size",
"rope_theta",
]
new_text_config_kwargs = {k: original_text_config[v] for k, v in text_key_mapping.items()}
new_text_config_kwargs.update({k: v for k, v in original_text_config.items() if k in similar_text_keys_to_keep})
# These are not always defined depending on `params.json`
new_text_config_kwargs["sliding_window"] = original_text_config.get("sliding_window", None)
new_text_config_kwargs["max_position_embeddings"] = original_text_config.get(
"max_seq_len", max_position_embeddings
)
# This may sometimes be a string in `params.json`
if new_text_config_kwargs["sliding_window"] is not None:
new_text_config_kwargs["sliding_window"] = int(new_text_config_kwargs["sliding_window"])
new_text_config = MistralConfig(**new_text_config_kwargs)
# Vision config
new_vision_config = original_vision_config
adapter_bias = new_vision_config.pop("adapter_bias", False)
_ = new_vision_config.pop("mm_projector_id", None)
_ = new_vision_config.pop("add_pre_mm_projector_layer_norm", None)
spatial_merge_size = new_vision_config.pop("spatial_merge_size")
image_token_id = new_vision_config.pop("image_token_id", 10)
_ = new_vision_config.pop("image_break_token_id", 12)
_ = new_vision_config.pop("image_end_token_id", 13)
_ = new_vision_config.pop("max_image_size")
new_vision_config = PixtralVisionConfig(**new_vision_config)
new_config = Mistral3Config(
vision_config=new_vision_config,
text_config=new_text_config,
multimodal_projector_bias=adapter_bias,
image_token_index=image_token_id,
spatial_merge_size=spatial_merge_size,
vision_feature_layer=-1,
)
return new_config
def convert_and_write_model(input_dir: str, output_dir: str, max_position_embeddings: int):
"""Convert the model and save it (this implicitly save the config as well)."""
params = read_json(os.path.join(input_dir, "params.json"))
config = convert_config(params, max_position_embeddings)
full_state_dict = {}
# The model may be split between different files, but a single nn.Module is always fully present in a single file
shards = [file for file in os.listdir(input_dir) if file.endswith(".safetensors")]
for shard_file in shards:
original_state_dict = load_file(os.path.join(input_dir, shard_file))
new_dict = convert_state_dict(original_state_dict, config)
full_state_dict.update(new_dict)
# Load weights into model and resave them
with torch.device("meta"):
model = Mistral3ForConditionalGeneration(config)
model.load_state_dict(full_state_dict, strict=True, assign=True)
model.save_pretrained(output_dir)
def convert_and_write_processor(input_dir: str, output_dir: str):
"""Convert the tokenizer and save it."""
tokenizer_file = os.path.join(input_dir, "tekken.json")
tokenizer = convert_tekken_tokenizer(tokenizer_file)
tokenizer.add_special_tokens({"pad_token": "<pad>"})
chat_template = '{%- if messages[0]["role"] == "system" %}{%- set system_message = messages[0]["content"] %}{%- set loop_messages = messages[1:] %}\n{%- else %}{%- set loop_messages = messages %}{%- endif %}{{- bos_token }}{%- for message in loop_messages %}{%- if (message[\'role\'] == \'user\') != (loop.index0 % 2 == 0) %}{{- raise_exception(\'After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\') }}{%- endif %}{%- if message["role"] == "user" %}{%- if loop.last and system_message is defined %}{{- "[INST]" + system_message + "\n\n" }}{%- else %}{{ "[INST]" }}{%- endif %}{%- endif %}{%- if message["content"] is not string %}{%- for chunk in message["content"] %}{%- if chunk["type"] == "text" %}{%- if "content" in chunk %}{{- chunk["content"] }}{%- elif "text" in chunk %}{{- chunk["text"] }}{%- endif %}{%- elif chunk["type"] == "image" %}{{- "[IMG]" }}{%- else %}{{- raise_exception("Unrecognized content type!") }}{%- endif %}{%- endfor %}{%- else %}{{- message["content"] }}{%- endif %}{%- if message["role"] == "user" %}{{- "[/INST]" }}{%- elif message["role"] == "assistant" %}{{- eos_token}}{%- else %}{{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}{%- endif %}{%- endfor %}'
config = read_json(os.path.join(input_dir, "params.json"))
patch_size = config["vision_encoder"]["patch_size"]
spatial_merge_size = config["vision_encoder"]["spatial_merge_size"]
max_image_size = config["vision_encoder"]["max_image_size"]
image_processor = PixtralImageProcessorFast(patch_size=patch_size, size={"longest_edge": max_image_size})
processor = PixtralProcessor(
tokenizer=tokenizer,
image_processor=image_processor,
image_token="[IMG]",
patch_size=patch_size,
chat_template=chat_template,
spatial_merge_size=spatial_merge_size,
)
# Finally save it
processor.save_pretrained(output_dir)
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"input_dir",
help="Location of Mistral weights, which contains tokenizer.model and model folders",
)
parser.add_argument(
"output_dir",
help="Location to write HF model and tokenizer",
)
parser.add_argument(
"--max_position_embeddings",
type=int,
default=131072,
help="`max_position_embeddings` field in the config. This needs to be manually passed (not present anywhere otherwise).",
)
args = parser.parse_args()
convert_and_write_model(args.input_dir, args.output_dir, args.max_position_embeddings)
convert_and_write_processor(args.input_dir, args.output_dir)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,553 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/mistral3/modular_mistral3.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_mistral3.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import torch
from torch import nn
from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_outputs import ModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_torchdynamo_compiling,
replace_return_docstrings,
)
from ...utils.deprecation import deprecate_kwarg
from ..auto import AutoModel, AutoModelForCausalLM
from .configuration_mistral3 import Mistral3Config
_CONFIG_FOR_DOC = "Mistral3Config"
class Mistral3RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
Mistral3RMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class Mistral3PatchMerger(nn.Module):
"""
Learned merging of spatial_merge_size ** 2 patches
"""
def __init__(self, config: Mistral3Config):
super().__init__()
self.config = config
hidden_size = config.vision_config.hidden_size
self.spatial_merge_size = config.spatial_merge_size
self.patch_size = self.config.vision_config.patch_size
self.merging_layer = nn.Linear(hidden_size * self.spatial_merge_size**2, hidden_size, bias=False)
def forward(self, image_features: torch.Tensor, image_sizes: torch.Tensor) -> torch.Tensor:
image_sizes = [
(image_size[0] // self.patch_size, image_size[1] // self.patch_size) for image_size in image_sizes
]
tokens_per_image = [h * w for h, w in image_sizes]
d = image_features.shape[-1]
permuted_tensor = []
for image_index, image_tokens in enumerate(image_features.split(tokens_per_image)):
# Reshape image_tokens into a 2D grid
h, w = image_sizes[image_index]
image_grid = image_tokens.view(h, w, d).permute(2, 0, 1).unsqueeze(0)
grid = torch.nn.functional.unfold(
image_grid, kernel_size=self.spatial_merge_size, stride=self.spatial_merge_size
)
grid = grid.view(d * self.spatial_merge_size**2, -1).t()
permuted_tensor.append(grid)
image_features = torch.cat(permuted_tensor, dim=0)
image_features = self.merging_layer(image_features)
return image_features
class Mistral3MultiModalProjector(nn.Module):
def __init__(self, config: Mistral3Config):
super().__init__()
self.norm = Mistral3RMSNorm(config.vision_config.hidden_size)
self.patch_merger = Mistral3PatchMerger(config)
# We have hidden_size * the number of vision feature layers
num_feature_layers = 1 if isinstance(config.vision_feature_layer, int) else len(config.vision_feature_layer)
self.linear_1 = nn.Linear(
config.vision_config.hidden_size * num_feature_layers,
config.text_config.hidden_size,
bias=config.multimodal_projector_bias,
)
self.act = ACT2FN[config.projector_hidden_act]
self.linear_2 = nn.Linear(
config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
)
def forward(self, image_features: torch.Tensor, image_sizes: torch.Tensor):
image_features = self.norm(image_features)
image_features = self.patch_merger(image_features, image_sizes)
hidden_states = self.linear_1(image_features)
hidden_states = self.act(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
@dataclass
class Mistral3CausalLMOutputWithPast(ModelOutput):
"""
Base class for Mistral3 causal language model (or autoregressive) outputs.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss (for next-token prediction).
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size (batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
past_key_values: Optional[List[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
image_hidden_states: Optional[torch.FloatTensor] = None
MISTRAL3_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`Mistral3Config`] or [`Mistral3VisionConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
MISTRAL3_START_DOCSTRING,
)
class Mistral3PreTrainedModel(PreTrainedModel):
config_class = Mistral3Config
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["Mistral3VisionAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_cache_class = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_quantized_cache = True
_supports_static_cache = True
def _init_weights(self, module):
# important: this ported version of Mistral3 isn't meant for training from scratch - only
# inference and fine-tuning - so the proper init weights code has been removed - the original codebase
# https://github.com/haotian-liu/Mistral3/tree/main/mistral3 should serve for that purpose
std = (
self.config.initializer_range
if hasattr(self.config, "initializer_range")
else self.config.text_config.initializer_range
)
if hasattr(module, "class_embedding"):
module.class_embedding.data.normal_(mean=0.0, std=std)
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
MISTRAL3_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)):
The tensors corresponding to the input images. Pixel values can be obtained using
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details ([]`Mistral3Processor`] uses
[`CLIPImageProcessor`] for processing images).
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
`past_key_values`).
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
information on the default strategy.
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
vision_feature_layer (`Union[int, List[int]], *optional*, defaults to -2`):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
the complete sequence length.
"""
@add_start_docstrings(
"""The MISTRAL3 model which consists of a vision backbone and a language model.""",
MISTRAL3_START_DOCSTRING,
)
class Mistral3ForConditionalGeneration(Mistral3PreTrainedModel, GenerationMixin):
def __init__(self, config: Mistral3Config):
super().__init__(config)
self.vision_tower = AutoModel.from_config(config.vision_config)
self.multi_modal_projector = Mistral3MultiModalProjector(config)
self.vocab_size = config.text_config.vocab_size
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
if self.language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self.post_init()
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
def get_output_embeddings(self):
return self.language_model.get_output_embeddings()
def set_output_embeddings(self, new_embeddings):
self.language_model.set_output_embeddings(new_embeddings)
def set_decoder(self, decoder):
self.language_model.set_decoder(decoder)
def get_decoder(self):
return self.language_model.get_decoder()
def get_image_features(
self,
pixel_values: torch.FloatTensor,
vision_feature_layer: Union[int, List[int]],
image_sizes: torch.Tensor,
**kwargs,
):
"""
Obtains image last hidden states from the vision tower and apply multimodal projection.
Args:
pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`):
The tensors corresponding to the input images.
vision_feature_layer (`Union[int, List[int]]`):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
image_sizes (`torch.Tensor`):
Tensor containing the image sizes as returned by the processor.
Returns:
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
"""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
# this is not memory efficient at all (output_hidden_states=True) will save all the hidden states.
image_outputs = self.vision_tower(pixel_values, image_sizes=image_sizes, output_hidden_states=True, **kwargs)
# If we have one vision feature layer, return the corresponding hidden states,
# otherwise, select the hidden states of each feature layer and concatenate them
if isinstance(vision_feature_layer, int):
selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
else:
hs_pool = [image_outputs.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
selected_image_feature = torch.cat(hs_pool, dim=-1)
image_features = self.multi_modal_projector(selected_image_feature.squeeze(0), image_sizes)
return image_features
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(MISTRAL3_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Mistral3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
vision_feature_layer: Optional[Union[int, List[int]]] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
image_sizes: torch.Tensor = None,
**lm_kwargs,
) -> Union[Tuple, Mistral3CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
Example:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, Mistral3ForConditionalGeneration
>>> model = Mistral3ForConditionalGeneration.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
>>> processor = AutoProcessor.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
>>> prompt = "<s>[INST][IMG]What is the image?[/INST]"
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(**inputs, max_new_tokens=15)
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"What is the image?The image depicts two cats lying on a pink blanket."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None:
image_features = self.get_image_features(
pixel_values=pixel_values,
vision_feature_layer=vision_feature_layer,
image_sizes=image_sizes,
)
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_index).sum()
n_image_features = image_features.shape[0] * image_features.shape[1]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
outputs = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
**lm_kwargs,
)
logits = outputs[0]
loss = None
if labels is not None:
# Shift so that tokens < n predict n
if attention_mask is not None:
# we use the input attention mask to shift the logits and labels, because it is 2D.
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
else:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return Mistral3CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
inputs_embeds=None,
pixel_values=None,
attention_mask=None,
cache_position=None,
logits_to_keep=None,
**kwargs,
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
model_inputs = self.language_model.prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
**kwargs,
)
if cache_position[0] == 0:
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
# Otherwise we need pixel values to be passed to model
model_inputs["pixel_values"] = pixel_values
return model_inputs
__all__ = ["Mistral3PreTrainedModel", "Mistral3ForConditionalGeneration"]

View File

@@ -0,0 +1,286 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Union
import torch
from torch import nn
from ...activations import ACT2FN
from ...utils import is_torchdynamo_compiling, logging
from ..llava.modeling_llava import LlavaCausalLMOutputWithPast, LlavaForConditionalGeneration
from ..mistral.modeling_mistral import MistralRMSNorm
from .configuration_mistral3 import Mistral3Config
logger = logging.get_logger(__name__)
class Mistral3RMSNorm(MistralRMSNorm):
pass
class Mistral3PatchMerger(nn.Module):
"""
Learned merging of spatial_merge_size ** 2 patches
"""
def __init__(self, config: Mistral3Config):
super().__init__()
self.config = config
hidden_size = config.vision_config.hidden_size
self.spatial_merge_size = config.spatial_merge_size
self.patch_size = self.config.vision_config.patch_size
self.merging_layer = nn.Linear(hidden_size * self.spatial_merge_size**2, hidden_size, bias=False)
def forward(self, image_features: torch.Tensor, image_sizes: torch.Tensor) -> torch.Tensor:
image_sizes = [
(image_size[0] // self.patch_size, image_size[1] // self.patch_size) for image_size in image_sizes
]
tokens_per_image = [h * w for h, w in image_sizes]
d = image_features.shape[-1]
permuted_tensor = []
for image_index, image_tokens in enumerate(image_features.split(tokens_per_image)):
# Reshape image_tokens into a 2D grid
h, w = image_sizes[image_index]
image_grid = image_tokens.view(h, w, d).permute(2, 0, 1).unsqueeze(0)
grid = torch.nn.functional.unfold(
image_grid, kernel_size=self.spatial_merge_size, stride=self.spatial_merge_size
)
grid = grid.view(d * self.spatial_merge_size**2, -1).t()
permuted_tensor.append(grid)
image_features = torch.cat(permuted_tensor, dim=0)
image_features = self.merging_layer(image_features)
return image_features
class Mistral3MultiModalProjector(nn.Module):
def __init__(self, config: Mistral3Config):
super().__init__()
self.norm = Mistral3RMSNorm(config.vision_config.hidden_size)
self.patch_merger = Mistral3PatchMerger(config)
# We have hidden_size * the number of vision feature layers
num_feature_layers = 1 if isinstance(config.vision_feature_layer, int) else len(config.vision_feature_layer)
self.linear_1 = nn.Linear(
config.vision_config.hidden_size * num_feature_layers,
config.text_config.hidden_size,
bias=config.multimodal_projector_bias,
)
self.act = ACT2FN[config.projector_hidden_act]
self.linear_2 = nn.Linear(
config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
)
def forward(self, image_features: torch.Tensor, image_sizes: torch.Tensor):
image_features = self.norm(image_features)
image_features = self.patch_merger(image_features, image_sizes)
hidden_states = self.linear_1(image_features)
hidden_states = self.act(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
class Mistral3CausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
pass
class Mistral3ForConditionalGeneration(LlavaForConditionalGeneration):
def get_image_features(
self,
pixel_values: torch.FloatTensor,
vision_feature_layer: Union[int, List[int]],
image_sizes: torch.Tensor,
**kwargs,
):
"""
Obtains image last hidden states from the vision tower and apply multimodal projection.
Args:
pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`):
The tensors corresponding to the input images.
vision_feature_layer (`Union[int, List[int]]`):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
image_sizes (`torch.Tensor`):
Tensor containing the image sizes as returned by the processor.
Returns:
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
"""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
# this is not memory efficient at all (output_hidden_states=True) will save all the hidden states.
image_outputs = self.vision_tower(pixel_values, image_sizes=image_sizes, output_hidden_states=True, **kwargs)
# If we have one vision feature layer, return the corresponding hidden states,
# otherwise, select the hidden states of each feature layer and concatenate them
if isinstance(vision_feature_layer, int):
selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
else:
hs_pool = [image_outputs.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
selected_image_feature = torch.cat(hs_pool, dim=-1)
image_features = self.multi_modal_projector(selected_image_feature.squeeze(0), image_sizes)
return image_features
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
vision_feature_layer: Optional[Union[int, List[int]]] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
image_sizes: torch.Tensor = None,
**lm_kwargs,
) -> Union[Tuple, Mistral3CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
Example:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, Mistral3ForConditionalGeneration
>>> model = Mistral3ForConditionalGeneration.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
>>> processor = AutoProcessor.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
>>> prompt = "<s>[INST][IMG]What is the image?[/INST]"
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(**inputs, max_new_tokens=15)
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"What is the image?The image depicts two cats lying on a pink blanket."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None:
image_features = self.get_image_features(
pixel_values=pixel_values,
vision_feature_layer=vision_feature_layer,
image_sizes=image_sizes,
)
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_index).sum()
n_image_features = image_features.shape[0] * image_features.shape[1]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
outputs = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
**lm_kwargs,
)
logits = outputs[0]
loss = None
if labels is not None:
# Shift so that tokens < n predict n
if attention_mask is not None:
# we use the input attention mask to shift the logits and labels, because it is 2D.
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
else:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return Mistral3CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
__all__ = [
"Mistral3PreTrainedModel", # noqa
"Mistral3ForConditionalGeneration",
]

View File

@@ -128,8 +128,9 @@ def get_resize_output_image_size(
if ratio > 1: if ratio > 1:
# Orgiginal implementation uses `round` which utilises bankers rounding, which can lead to surprising results # Orgiginal implementation uses `round` which utilises bankers rounding, which can lead to surprising results
height = int(math.ceil(height / ratio)) # Here we use floor to ensure the image is always smaller than the given "longest_edge"
width = int(math.ceil(width / ratio)) height = int(math.floor(height / ratio))
width = int(math.floor(width / ratio))
num_height_tokens, num_width_tokens = _num_image_tokens((height, width), (patch_height, patch_width)) num_height_tokens, num_width_tokens = _num_image_tokens((height, width), (patch_height, patch_width))
return num_height_tokens * patch_height, num_width_tokens * patch_width return num_height_tokens * patch_height, num_width_tokens * patch_width

View File

@@ -64,6 +64,8 @@ class PixtralProcessor(ProcessorMixin):
The tokenizer is a required input. The tokenizer is a required input.
patch_size (`int`, *optional*, defaults to 16): patch_size (`int`, *optional*, defaults to 16):
Patch size from the vision tower. Patch size from the vision tower.
spatial_merge_size (`int`, *optional*, defaults to 1):
The downsampling factor for the spatial merge operation.
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
in a chat into a tokenizable string. in a chat into a tokenizable string.
image_token (`str`, *optional*, defaults to `"[IMG]"`): image_token (`str`, *optional*, defaults to `"[IMG]"`):
@@ -78,6 +80,7 @@ class PixtralProcessor(ProcessorMixin):
valid_kwargs = [ valid_kwargs = [
"chat_template", "chat_template",
"patch_size", "patch_size",
"spatial_merge_size",
"image_token", "image_token",
"image_break_token", "image_break_token",
"image_end_token", "image_end_token",
@@ -90,6 +93,7 @@ class PixtralProcessor(ProcessorMixin):
image_processor=None, image_processor=None,
tokenizer=None, tokenizer=None,
patch_size: int = 16, patch_size: int = 16,
spatial_merge_size: int = 1,
chat_template=None, chat_template=None,
image_token="[IMG]", # set the default and let users change if they have peculiar special tokens in rare cases image_token="[IMG]", # set the default and let users change if they have peculiar special tokens in rare cases
image_break_token="[IMG_BREAK]", image_break_token="[IMG_BREAK]",
@@ -97,6 +101,7 @@ class PixtralProcessor(ProcessorMixin):
**kwargs, **kwargs,
): ):
self.patch_size = patch_size self.patch_size = patch_size
self.spatial_merge_size = spatial_merge_size
self.image_token = image_token self.image_token = image_token
self.image_break_token = image_break_token self.image_break_token = image_break_token
self.image_end_token = image_end_token self.image_end_token = image_end_token
@@ -187,8 +192,8 @@ class PixtralProcessor(ProcessorMixin):
for sample in text: for sample in text:
while self.image_token in sample: while self.image_token in sample:
height, width = next(image_sizes) height, width = next(image_sizes)
num_height_tokens = height // self.patch_size num_height_tokens = height // (self.patch_size * self.spatial_merge_size)
num_width_tokens = width // self.patch_size num_width_tokens = width // (self.patch_size * self.spatial_merge_size)
replace_tokens = [ replace_tokens = [
[self.image_token] * num_width_tokens + [self.image_break_token] [self.image_token] * num_width_tokens + [self.image_break_token]
] * num_height_tokens ] * num_height_tokens

View File

@@ -6392,6 +6392,20 @@ class MistralPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
class Mistral3ForConditionalGeneration(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Mistral3PreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class MixtralForCausalLM(metaclass=DummyObject): class MixtralForCausalLM(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]

View File

@@ -125,6 +125,7 @@ VLM_CLASS_NAMES = [
"qwen2_5_vl", "qwen2_5_vl",
"ayavision", "ayavision",
"gemma3", "gemma3",
"mistral3",
] ]

View File

View File

@@ -0,0 +1,482 @@
# coding=utf-8
# Copyright 2024 The Qwen team, Alibaba Group and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch GotOcr2 model."""
import unittest
from transformers import (
AutoProcessor,
Mistral3Config,
is_bitsandbytes_available,
is_torch_available,
)
from transformers.testing_utils import (
cleanup,
require_bitsandbytes,
require_torch,
require_torch_gpu,
slow,
torch_device,
)
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
import torch
from transformers import (
Mistral3ForConditionalGeneration,
)
if is_bitsandbytes_available():
from transformers import BitsAndBytesConfig
class Mistral3VisionText2TextModelTester:
def __init__(
self,
parent,
batch_size=3,
seq_length=7,
image_seq_length=4,
vision_feature_layer=-1,
ignore_index=-100,
bos_token_id=0,
eos_token_id=0,
pad_token_id=0,
image_token_index=1,
num_channels=3,
image_size=30,
model_type="mistral3",
is_training=True,
text_config={
"model_type": "mistral",
"vocab_size": 99,
"attention_dropout": 0.0,
"hidden_act": "silu",
"hidden_size": 32,
"initializer_range": 0.02,
"intermediate_size": 37,
"max_position_embeddings": 512,
"num_attention_heads": 4,
"num_hidden_layers": 2,
"num_key_value_heads": 2,
"rms_norm_eps": 1e-05,
"rope_theta": 1000000000.0,
"sliding_window": None,
"bos_token_id": 0,
"eos_token_id": 0,
"pad_token_id": 0,
},
vision_config={
"model_type": "pixtral",
"hidden_size": 32,
"num_hidden_layers": 2,
"num_attention_heads": 4,
"intermediate_size": 37,
"image_size": 30,
"patch_size": 6,
"num_channels": 3,
"hidden_act": "gelu",
},
):
self.parent = parent
self.ignore_index = ignore_index
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.pad_token_id = pad_token_id
self.image_token_index = image_token_index
self.model_type = model_type
self.text_config = text_config
self.vision_config = vision_config
self.batch_size = batch_size
self.vision_feature_layer = vision_feature_layer
self.is_training = is_training
self.image_seq_length = image_seq_length
self.num_channels = num_channels
self.image_size = image_size
self.seq_length = seq_length + self.image_seq_length
self.num_hidden_layers = text_config["num_hidden_layers"]
self.vocab_size = text_config["vocab_size"]
self.hidden_size = text_config["hidden_size"]
self.num_attention_heads = text_config["num_attention_heads"]
def get_config(self):
return Mistral3Config(
text_config=self.text_config,
vision_config=self.vision_config,
model_type=self.model_type,
bos_token_id=self.bos_token_id,
eos_token_id=self.eos_token_id,
pad_token_id=self.pad_token_id,
image_token_index=self.image_token_index,
image_seq_length=self.image_seq_length,
vision_feature_layer=self.vision_feature_layer,
)
def prepare_config_and_inputs(self):
config = self.get_config()
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
return config, pixel_values
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values = config_and_inputs
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
image_sizes = torch.tensor(
[[self.image_size, self.image_size]] * self.batch_size, dtype=torch.long, device=torch_device
)
# input_ids[:, -1] = self.pad_token_id
input_ids[input_ids == self.image_token_index] = self.pad_token_id
input_ids[:, : self.image_seq_length] = self.image_token_index
inputs_dict = {
"pixel_values": pixel_values,
"input_ids": input_ids,
"attention_mask": attention_mask,
"image_sizes": image_sizes,
}
return config, inputs_dict
def create_and_check_model_fp16_forward(self, config, input_ids, pixel_values, attention_mask):
model = Mistral3ForConditionalGeneration(config=config)
model.to(torch_device)
model.half()
model.eval()
logits = model(
input_ids=input_ids,
attention_mask=attention_mask,
pixel_values=pixel_values.to(torch.bfloat16),
return_dict=True,
)["logits"]
self.parent.assertFalse(torch.isnan(logits).any().item())
def create_and_check_model_fp16_autocast_forward(self, config, input_ids, pixel_values, attention_mask):
config.torch_dtype = torch.float16
model = Mistral3ForConditionalGeneration(config=config)
model.to(torch_device)
model.eval()
with torch.autocast(device_type="cuda", dtype=torch.float16):
logits = model(
input_ids=input_ids,
attention_mask=attention_mask,
pixel_values=pixel_values.to(torch.bfloat16),
return_dict=True,
)["logits"]
self.parent.assertFalse(torch.isnan(logits).any().item())
@require_torch
class Mistral3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (Mistral3ForConditionalGeneration,) if is_torch_available() else ()
all_generative_model_classes = (Mistral3ForConditionalGeneration,) if is_torch_available() else ()
pipeline_model_mapping = (
{
"image-text-to-text": Mistral3ForConditionalGeneration,
}
if is_torch_available()
else {}
)
_is_composite = True
test_headmasking = False
test_pruning = False
def setUp(self):
self.model_tester = Mistral3VisionText2TextModelTester(self)
self.config_tester = ConfigTester(self, config_class=Mistral3Config, has_text_modality=False)
def test_config(self):
# overwritten from `tests/test_configuration_common.py::ConfigTester` after #36077
# TODO: avoid overwritten once there is a better fix for #36077
def check_config_can_be_init_without_params():
config = self.config_tester.config_class()
self.config_tester.parent.assertIsNotNone(config)
self.config_tester.check_config_can_be_init_without_params = check_config_can_be_init_without_params
self.config_tester.run_common_tests()
def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
configs_no_init = _config_zero_init(config)
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
for name, param in model.named_parameters():
if param.requires_grad:
self.assertIn(
((param.data.mean() * 1e9).round() / 1e9).item(),
[0.0, 1.0],
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
def test_inputs_embeds(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
wte = model.get_input_embeddings()
inputs["inputs_embeds"] = wte(input_ids)
with torch.no_grad():
model(**inputs)
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
# while some other models require pixel_values to be present
def test_inputs_embeds_matches_input_ids(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
inputs_embeds = model.get_input_embeddings()(input_ids)
with torch.no_grad():
out_ids = model(input_ids=input_ids, **inputs)[0]
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
torch.testing.assert_close(out_embeds, out_ids)
@unittest.skip(reason="Compile not yet supported because in LLava models")
def test_sdpa_can_compile_dynamic(self):
pass
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
def test_flash_attn_2_fp32_ln(self):
pass
@unittest.skip("Pixtral does not support attention interfaces.")
def test_eager_matches_fa2_generate(self):
pass
@unittest.skip("Pixtral does not support attention interfaces.")
def test_eager_matches_sdpa_generate(self):
pass
@unittest.skip("Pixtral does not support attention interfaces.")
def test_flash_attn_2_from_config(self):
pass
@unittest.skip("Pixtral does not support attention interfaces.")
def test_flash_attn_2_inference_equivalence(self):
pass
@unittest.skip("Pixtral does not support attention interfaces.")
def test_flash_attn_2_inference_equivalence_right_padding(self):
pass
@unittest.skip("Pixtral does not support attention interfaces.")
def test_sdpa_can_dispatch_on_flash(self):
pass
@slow
@require_torch_gpu
class Mistral3IntegrationTest(unittest.TestCase):
def setUp(self):
self.model_checkpoint = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
def tearDown(self):
cleanup(torch_device, gc_collect=True)
def test_mistral3_integration_generate_text_only(self):
processor = AutoProcessor.from_pretrained(self.model_checkpoint)
model = Mistral3ForConditionalGeneration.from_pretrained(
self.model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16
)
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": "Write a haiku"},
],
}
]
inputs = processor.apply_chat_template(
messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
).to(torch_device, dtype=torch.bfloat16)
with torch.no_grad():
generate_ids = model.generate(**inputs, max_new_tokens=200, do_sample=False)
decoded_output = processor.decode(
generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True
)
expected_output = "Sure, here's a haiku for you:\n\nWhispers of the breeze,\nCherry blossoms softly fall,\nSpring's gentle embrace."
self.assertEqual(decoded_output, expected_output)
def test_mistral3_integration_generate(self):
processor = AutoProcessor.from_pretrained(self.model_checkpoint)
model = Mistral3ForConditionalGeneration.from_pretrained(
self.model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16
)
messages = [
{
"role": "user",
"content": [
{"type": "image", "url": "http://images.cocodataset.org/val2017/000000039769.jpg"},
{"type": "text", "text": "Describe this image"},
],
}
]
inputs = processor.apply_chat_template(
messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
).to(torch_device, dtype=torch.bfloat16)
with torch.no_grad():
generate_ids = model.generate(**inputs, max_new_tokens=20, do_sample=False)
decoded_output = processor.decode(
generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True
)
expected_output = "The image depicts two cats lying on a pink blanket. The larger cat, which appears to be an"
self.assertEqual(decoded_output, expected_output)
def test_mistral3_integration_batched_generate(self):
processor = AutoProcessor.from_pretrained(self.model_checkpoint)
model = Mistral3ForConditionalGeneration.from_pretrained(
self.model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16
)
messages = [
[
{
"role": "user",
"content": [
{"type": "image", "url": "https://llava-vl.github.io/static/images/view.jpg"},
{"type": "text", "text": "Write a haiku for this image"},
],
},
],
[
{
"role": "user",
"content": [
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
{"type": "text", "text": "Describe this image"},
],
},
],
]
inputs = processor.apply_chat_template(
messages, padding=True, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
).to(model.device, dtype=torch.bfloat16)
output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
# Check first output
decoded_output = processor.decode(output[0], skip_special_tokens=True)
expected_output = "Write a haiku for this imageSure, here is a haiku inspired by the image:\n\nCalm lake's mirror gleams,\nWhispering pines"
self.assertEqual(
decoded_output,
expected_output,
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
)
# Check second output
decoded_output = processor.decode(output[1], skip_special_tokens=True)
expected_output = "Describe this imageThe image depicts a vibrant street scene in what appears to be a Chinatown district. The focal point is a traditional Chinese"
self.assertEqual(
decoded_output,
expected_output,
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
)
@require_bitsandbytes
def test_mistral3_integration_batched_generate_multi_image(self):
processor = AutoProcessor.from_pretrained(self.model_checkpoint)
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
model = Mistral3ForConditionalGeneration.from_pretrained(
self.model_checkpoint, quantization_config=quantization_config
)
# Prepare inputs
messages = [
[
{
"role": "user",
"content": [
{"type": "image", "url": "https://llava-vl.github.io/static/images/view.jpg"},
{"type": "text", "text": "Write a haiku for this image"},
],
},
],
[
{
"role": "user",
"content": [
{
"type": "image",
"url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
},
{
"type": "image",
"url": "https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg",
},
{
"type": "text",
"text": "These images depict two different landmarks. Can you identify them?",
},
],
},
],
]
inputs = processor.apply_chat_template(
messages, padding=True, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
).to(model.device, dtype=torch.float16)
output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
# Check first output
decoded_output = processor.decode(output[0], skip_special_tokens=True)
expected_output = "Write a haiku for this imageSure, here is a haiku inspired by the image:\n\nCalm lake's wooden path\nSilent forest stands guard\n"
self.assertEqual(
decoded_output,
expected_output,
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
)
# Check second output
decoded_output = processor.decode(output[1], skip_special_tokens=True)
expected_output = "These images depict two different landmarks. Can you identify them?Certainly! The images depict two iconic landmarks:\n\n1. The first image shows the Statue of Liberty in New York City."
self.assertEqual(
decoded_output,
expected_output,
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
)

View File

@@ -0,0 +1,293 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
import tempfile
import unittest
import requests
from transformers import PixtralProcessor
from transformers.testing_utils import require_vision
from transformers.utils import is_torch_available, is_vision_available
from ...test_processing_common import ProcessorTesterMixin
if is_torch_available():
import torch
if is_vision_available():
from PIL import Image
@require_vision
class Mistral3ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
"""This tests Pixtral processor with the new `spatial_merge_size` argument in Mistral3."""
processor_class = PixtralProcessor
@classmethod
def setUpClass(cls):
cls.url_0 = "https://www.ilankelman.org/stopsigns/australia.jpg"
cls.image_0 = Image.open(requests.get(cls.url_0, stream=True).raw)
cls.url_1 = "http://images.cocodataset.org/val2017/000000039769.jpg"
cls.image_1 = Image.open(requests.get(cls.url_1, stream=True).raw)
cls.url_2 = "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg"
cls.image_2 = Image.open(requests.get(cls.url_2, stream=True).raw)
def setUp(self):
self.tmpdirname = tempfile.mkdtemp()
processor = PixtralProcessor.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
processor.save_pretrained(self.tmpdirname)
def tearDown(self):
shutil.rmtree(self.tmpdirname)
def test_chat_template(self):
processor = self.processor_class.from_pretrained(self.tmpdirname)
expected_prompt = "<s>[INST][IMG]What is shown in this image?[/INST]"
messages = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": "What is shown in this image?"},
],
},
]
formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
self.assertEqual(expected_prompt, formatted_prompt)
def test_image_token_filling(self):
processor = self.processor_class.from_pretrained(self.tmpdirname)
# Important to check with non square image
image = torch.randint(0, 2, (3, 500, 316))
expected_image_tokens = 198
image_token_index = 10
messages = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": "What is shown in this image?"},
],
},
]
inputs = processor(
text=[processor.apply_chat_template(messages)],
images=[image],
return_tensors="pt",
)
image_tokens = (inputs["input_ids"] == image_token_index).sum().item()
self.assertEqual(expected_image_tokens, image_tokens)
def test_processor_with_single_image(self):
processor = self.processor_class.from_pretrained(self.tmpdirname)
prompt_string = "USER: [IMG]\nWhat's the content of the image? ASSISTANT:"
# Make small for checking image token expansion
processor.image_processor.size = {"longest_edge": 30}
processor.patch_size = 6
# Test passing in an image
inputs_image = processor(text=prompt_string, images=self.image_0, return_tensors="pt")
self.assertIn("input_ids", inputs_image)
self.assertTrue(len(inputs_image["input_ids"]) == 1)
self.assertIsInstance(inputs_image["input_ids"], torch.Tensor)
self.assertIsInstance(inputs_image["pixel_values"], torch.Tensor)
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([1, 3, 24, 30]))
# fmt: off
input_ids = inputs_image["input_ids"]
self.assertEqual(
input_ids[0].tolist(),
# Equivalent to "USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the content of the image? ASSISTANT:"
[1, 21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 4701, 1307, 1278, 3937, 1063, 1349, 4290, 16002, 41150, 1058]
)
# fmt: on
# Test passing in a url
inputs_url = processor(text=prompt_string, images=self.url_0, return_tensors="pt")
self.assertIn("input_ids", inputs_url)
self.assertTrue(len(inputs_url["input_ids"]) == 1)
self.assertIsInstance(inputs_url["input_ids"], torch.Tensor)
self.assertIsInstance(inputs_image["pixel_values"], torch.Tensor)
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([1, 3, 24, 30]))
# fmt: off
input_ids = inputs_url["input_ids"]
self.assertEqual(
input_ids[0].tolist(),
# Equivalent to "USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the content of the image? ASSISTANT:"
[1, 21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 4701, 1307, 1278, 3937, 1063, 1349, 4290, 16002, 41150, 1058]
)
# fmt: on
# Test passing inputs as a single list
inputs_image = processor(text=prompt_string, images=[self.image_0], return_tensors="pt")
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([1, 3, 24, 30]))
# fmt: off
self.assertEqual(
inputs_image["input_ids"][0].tolist(),
[1, 21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 4701, 1307, 1278, 3937, 1063, 1349, 4290, 16002, 41150, 1058]
)
# fmt: on
# Test as nested single list
inputs_image = processor(text=prompt_string, images=[[self.image_0]], return_tensors="pt")
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([1, 3, 24, 30]))
# fmt: off
self.assertEqual(
inputs_image["input_ids"][0].tolist(),
[1, 21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 4701, 1307, 1278, 3937, 1063, 1349, 4290, 16002, 41150, 1058]
)
# fmt: on
def test_processor_with_multiple_images_single_list(self):
processor = self.processor_class.from_pretrained(self.tmpdirname)
prompt_string = "USER: [IMG][IMG]\nWhat's the difference between these two images? ASSISTANT:"
# Make small for checking image token expansion
processor.image_processor.size = {"longest_edge": 30}
processor.patch_size = 6
# Test passing in an image
inputs_image = processor(text=prompt_string, images=[self.image_0, self.image_1], return_tensors="pt")
self.assertIn("input_ids", inputs_image)
self.assertTrue(len(inputs_image["input_ids"]) == 1)
self.assertIsInstance(inputs_image["input_ids"], torch.Tensor)
self.assertIsInstance(inputs_image["pixel_values"], torch.Tensor)
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([2, 3, 24, 30]))
# fmt: off
input_ids = inputs_image["input_ids"]
self.assertEqual(
input_ids[0].tolist(),
# Equivalent to ["USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END][IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the difference between these two images? ASSISTANT:"]
[1, 21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058]
)
# fmt: on
# Test passing in a url
inputs_url = processor(text=prompt_string, images=[self.url_0, self.url_1], return_tensors="pt")
self.assertIn("input_ids", inputs_url)
self.assertTrue(len(inputs_url["input_ids"]) == 1)
self.assertIsInstance(inputs_url["input_ids"], torch.Tensor)
self.assertIsInstance(inputs_image["pixel_values"], torch.Tensor)
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([2, 3, 24, 30]))
# fmt: off
input_ids = inputs_url["input_ids"]
self.assertEqual(
input_ids[0].tolist(),
# Equivalent to ["USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END][IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the difference between these two images? ASSISTANT:"]
[1, 21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058]
)
# fmt: on
# Test passing in as a nested list
inputs_url = processor(text=prompt_string, images=[[self.image_0, self.image_1]], return_tensors="pt")
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([2, 3, 24, 30]))
# fmt: off
self.assertEqual(
inputs_url["input_ids"][0].tolist(),
[1, 21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058]
)
# fmt: on
def test_processor_with_multiple_images_multiple_lists(self):
processor = self.processor_class.from_pretrained(self.tmpdirname)
prompt_string = [
"USER: [IMG][IMG]\nWhat's the difference between these two images? ASSISTANT:",
"USER: [IMG]\nWhat's the content of the image? ASSISTANT:",
]
processor.tokenizer.pad_token = "</s>"
image_inputs = [[self.image_0, self.image_1], [self.image_2]]
# Make small for checking image token expansion
processor.image_processor.size = {"longest_edge": 30}
processor.patch_size = 6
# Test passing in an image
inputs_image = processor(text=prompt_string, images=image_inputs, return_tensors="pt", padding=True)
self.assertIn("input_ids", inputs_image)
self.assertTrue(len(inputs_image["input_ids"]) == 2)
self.assertIsInstance(inputs_image["input_ids"], torch.Tensor)
self.assertIsInstance(inputs_image["pixel_values"], torch.Tensor)
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([3, 3, 30, 30]))
# fmt: off
input_ids = inputs_image["input_ids"]
self.assertEqual(
input_ids[0].tolist(),
# Equivalent to ["USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END][IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the difference between these two images? ASSISTANT:"]
[1, 21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058]
)
# fmt: on
# Test passing in a url
inputs_url = processor(text=prompt_string, images=image_inputs, return_tensors="pt", padding=True)
self.assertIn("input_ids", inputs_url)
self.assertTrue(len(inputs_url["input_ids"]) == 2)
self.assertIsInstance(inputs_url["input_ids"], torch.Tensor)
self.assertIsInstance(inputs_image["pixel_values"], torch.Tensor)
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([3, 3, 30, 30]))
# fmt: off
input_ids = inputs_url["input_ids"]
self.assertEqual(
input_ids[0].tolist(),
# Equivalent to ["USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END][IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the difference between these two images? ASSISTANT:"]
[1, 21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058]
)
# fmt: on
# Test passing as a single flat list
inputs_image = processor(
text=prompt_string, images=[self.image_0, self.image_1, self.image_2], return_tensors="pt", padding=True
)
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([3, 3, 30, 30]))
# fmt: off
self.assertEqual(
inputs_image["input_ids"][0].tolist(),
[1, 21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058]
)
# fmt: on
def test_processor_returns_full_length_batches(self):
# to avoid https://github.com/huggingface/transformers/issues/34204
processor = self.processor_class.from_pretrained(self.tmpdirname)
prompt_string = [
"USER: [IMG]\nWhat's the content of the image? ASSISTANT:",
] * 5
processor.tokenizer.pad_token = "</s>"
image_inputs = [[self.image_0]] * 5
# Make small for checking image token expansion
processor.image_processor.size = {"longest_edge": 30}
processor.patch_size = 6
# Test passing in an image
inputs_image = processor(text=prompt_string, images=image_inputs, return_tensors="pt", padding=True)
self.assertIn("input_ids", inputs_image)
self.assertTrue(len(inputs_image["input_ids"]) == 5)

View File

@@ -109,8 +109,8 @@ class PixtralImageProcessingTester:
ratio = max(height / max_height, width / max_width) ratio = max(height / max_height, width / max_width)
if ratio > 1: if ratio > 1:
height = int(np.ceil(height / ratio)) height = int(np.floor(height / ratio))
width = int(np.ceil(width / ratio)) width = int(np.floor(width / ratio))
patch_height, patch_width = self.patch_size["height"], self.patch_size["width"] patch_height, patch_width = self.patch_size["height"], self.patch_size["width"]
num_height_tokens = (height - 1) // patch_height + 1 num_height_tokens = (height - 1) // patch_height + 1