diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 5c4a643507..79f8eb3d49 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -529,6 +529,8 @@ title: MegatronGPT2 - local: model_doc/mistral title: Mistral + - local: model_doc/mistral3 + title: Mistral3 - local: model_doc/mixtral title: Mixtral - local: model_doc/mluke diff --git a/docs/source/en/model_doc/mistral3.md b/docs/source/en/model_doc/mistral3.md new file mode 100644 index 0000000000..5b607294f6 --- /dev/null +++ b/docs/source/en/model_doc/mistral3.md @@ -0,0 +1,234 @@ + + +# Mistral3 + +## Overview + +Building upon Mistral Small 3 (2501), Mistral Small 3.1 (2503) adds state-of-the-art vision understanding and enhances long context capabilities up to 128k tokens without compromising text performance. With 24 billion parameters, this model achieves top-tier capabilities in both text and vision tasks. + +It is ideal for: +- Fast-response conversational agents. +- Low-latency function calling. +- Subject matter experts via fine-tuning. +- Local inference for hobbyists and organizations handling sensitive data. +- Programming and math reasoning. +- Long document understanding. +- Visual understanding. + +This model was contributed by [cyrilvallez](https://huggingface.co/cyrilvallez) and [yonigozlan](https://huggingface.co/yonigozlan). + +The original code can be found [here](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/pixtral.py) and [here](https://github.com/mistralai/mistral-common). + +## Usage example + +### Inference with Pipeline + +Here is how you can use the `image-text-to-text` pipeline to perform inference with the `Mistral3` models in just a few lines of code: +```python +>>> from transformers import pipeline + +>>> messages = [ +... { +... "role": "user", +... "content": [ +... { +... "type": "image", +... "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg", +... }, +... {"type": "text", "text": "Describe this image."}, +... ], +... }, +... ] + +>>> pipe = pipeline("image-text-to-text", model="../mistral3_weights", torch_dtype=torch.bfloat16) +>>> outputs = pipe(text=messages, max_new_tokens=50, return_full_text=False) +>>> outputs[0]["generated_text"] +'The image depicts a vibrant and lush garden scene featuring a variety of wildflowers and plants. The central focus is on a large, pinkish-purple flower, likely a Greater Celandine (Chelidonium majus), with a' +``` +### Inference on a single image + +This example demonstrates how to perform inference on a single image with the Mistral3 models using chat templates. + +```python +>>> from transformers import AutoProcessor, AutoModelForImageTextToText +>>> import torch + +>>> torch_device = "cuda" +>>> model_checkpoint = "mistralai/Mistral-Small-3.1-24B-Instruct-2503" +>>> processor = AutoProcessor.from_pretrained(model_checkpoint) +>>> model = AutoModelForImageTextToText.from_pretrained(model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16) + +>>> messages = [ +... { +... "role": "user", +... "content": [ +... {"type": "image", "url": "http://images.cocodataset.org/val2017/000000039769.jpg"}, +... {"type": "text", "text": "Describe this image"}, +... ], +... } +... ] + +>>> inputs = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt").to(model.device, dtype=torch.bfloat16) + +>>> generate_ids = model.generate(**inputs, max_new_tokens=20) +>>> decoded_output = processor.decode(generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True) + +>>> decoded_output +"The image depicts two cats lying on a pink blanket. The larger cat, which appears to be an"... +``` + +### Text-only generation +This example shows how to generate text using the Mistral3 model without providing any image input. + + +````python +>>> from transformers import AutoProcessor, AutoModelForImageTextToText +>>> import torch + +>>> torch_device = "cuda" +>>> model_checkpoint = ".mistralai/Mistral-Small-3.1-24B-Instruct-2503" +>>> processor = AutoProcessor.from_pretrained(model_checkpoint) +>>> model = AutoModelForImageTextToText.from_pretrained(model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16) + +>>> SYSTEM_PROMPT = "You are a conversational agent that always answers straight to the point, always end your accurate response with an ASCII drawing of a cat." +>>> user_prompt = "Give me 5 non-formal ways to say 'See you later' in French." + +>>> messages = [ +... {"role": "system", "content": SYSTEM_PROMPT}, +... {"role": "user", "content": user_prompt}, +... ] + +>>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) +>>> inputs = processor(text=text, return_tensors="pt").to(0, dtype=torch.float16) +>>> generate_ids = model.generate(**inputs, max_new_tokens=50, do_sample=False) +>>> decoded_output = processor.batch_decode(generate_ids[:, inputs["input_ids"].shape[1] :], skip_special_tokens=True)[0] + +>>> print(decoded_output) +"1. À plus tard! +2. Salut, à plus! +3. À toute! +4. À la prochaine! +5. Je me casse, à plus! + +``` + /\_/\ +( o.o ) + > ^ < +```" +```` + +### Batched image and text inputs +Mistral3 models also support batched image and text inputs. + +```python +>>> from transformers import AutoProcessor, AutoModelForImageTextToText +>>> import torch + +>>> torch_device = "cuda" +>>> model_checkpoint = "mistralai/Mistral-Small-3.1-24B-Instruct-2503" +>>> processor = AutoProcessor.from_pretrained(model_checkpoint) +>>> model = AutoModelForImageTextToText.from_pretrained(model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16) + +>>> messages = [ +... [ +... { +... "role": "user", +... "content": [ +... {"type": "image", "url": "https://llava-vl.github.io/static/images/view.jpg"}, +... {"type": "text", "text": "Write a haiku for this image"}, +... ], +... }, +... ], +... [ +... { +... "role": "user", +... "content": [ +... {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}, +... {"type": "text", "text": "Describe this image"}, +... ], +... }, +... ], +... ] + + +>>> inputs = processor.apply_chat_template(messages, padding=True, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt").to(model.device, dtype=torch.bfloat16) + +>>> output = model.generate(**inputs, max_new_tokens=25) + +>>> decoded_outputs = processor.batch_decode(output, skip_special_tokens=True) +>>> decoded_outputs +["Write a haiku for this imageCalm waters reflect\nWhispers of the forest's breath\nPeace on wooden path" +, "Describe this imageThe image depicts a vibrant street scene in what appears to be a Chinatown district. The focal point is a traditional Chinese"] +``` + +### Batched multi-image input and quantization with BitsAndBytes +This implementation of the Mistral3 models supports batched text-images inputs with different number of images for each text. +This example also how to use `BitsAndBytes` to load the model in 4bit quantization. + +```python +>>> from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig +>>> import torch + +>>> torch_device = "cuda" +>>> model_checkpoint = "mistralai/Mistral-Small-3.1-24B-Instruct-2503" +>>> processor = AutoProcessor.from_pretrained(model_checkpoint) +>>> quantization_config = BitsAndBytesConfig(load_in_4bit=True) +>>> model = AutoModelForImageTextToText.from_pretrained( +... model_checkpoint, quantization_config=quantization_config +... ) + +>>> messages = [ +...     [ +...         { +...             "role": "user", +...             "content": [ +...                 {"type": "image", "url": "https://llava-vl.github.io/static/images/view.jpg"}, +...                 {"type": "text", "text": "Write a haiku for this image"}, +...             ], +...         }, +...     ], +...     [ +...         { +...             "role": "user", +...             "content": [ +...                 {"type": "image", "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"}, +...                 {"type": "image", "url": "https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg"}, +...                 {"type": "text", "text": "These images depict two different landmarks. Can you identify them?"}, +...             ], +...         }, +...     ], +>>> ] + +>>> inputs = processor.apply_chat_template(messages, padding=True, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt").to(model.device, dtype=torch.bfloat16) + +>>> output = model.generate(**inputs, max_new_tokens=25) + +>>> decoded_outputs = processor.batch_decode(output, skip_special_tokens=True) +>>> decoded_outputs +["Write a haiku for this imageSure, here is a haiku inspired by the image:\n\nCalm lake's wooden path\nSilent forest stands guard\n", "These images depict two different landmarks. Can you identify them? Certainly! The images depict two iconic landmarks:\n\n1. The first image shows the Statue of Liberty in New York City."] +``` + + +## Mistral3Config + +[[autodoc]] Mistral3Config + + +## Mistral3ForConditionalGeneration + +[[autodoc]] Mistral3ForConditionalGeneration + - forward diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index da8b1cacaa..4c6ca0fe1e 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -613,6 +613,7 @@ _import_structure = { ], "models.mimi": ["MimiConfig"], "models.mistral": ["MistralConfig"], + "models.mistral3": ["Mistral3Config"], "models.mixtral": ["MixtralConfig"], "models.mllama": [ "MllamaConfig", @@ -2940,6 +2941,12 @@ else: "MistralPreTrainedModel", ] ) + _import_structure["models.mistral3"].extend( + [ + "Mistral3ForConditionalGeneration", + "Mistral3PreTrainedModel", + ] + ) _import_structure["models.mixtral"].extend( [ "MixtralForCausalLM", @@ -5788,6 +5795,7 @@ if TYPE_CHECKING: MimiConfig, ) from .models.mistral import MistralConfig + from .models.mistral3 import Mistral3Config from .models.mixtral import MixtralConfig from .models.mllama import ( MllamaConfig, @@ -7844,6 +7852,10 @@ if TYPE_CHECKING: MistralModel, MistralPreTrainedModel, ) + from .models.mistral3 import ( + Mistral3ForConditionalGeneration, + Mistral3PreTrainedModel, + ) from .models.mixtral import ( MixtralForCausalLM, MixtralForQuestionAnswering, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index c30b97ade7..3a72fca91e 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -169,6 +169,7 @@ from . import ( mgp_str, mimi, mistral, + mistral3, mixtral, mllama, mluke, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 3c6b849d8c..0969976937 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -192,6 +192,7 @@ CONFIG_MAPPING_NAMES = OrderedDict( ("mgp-str", "MgpstrConfig"), ("mimi", "MimiConfig"), ("mistral", "MistralConfig"), + ("mistral3", "Mistral3Config"), ("mixtral", "MixtralConfig"), ("mllama", "MllamaConfig"), ("mobilebert", "MobileBertConfig"), @@ -537,6 +538,7 @@ MODEL_NAMES_MAPPING = OrderedDict( ("mgp-str", "MGP-STR"), ("mimi", "Mimi"), ("mistral", "Mistral"), + ("mistral3", "Mistral3"), ("mixtral", "Mixtral"), ("mllama", "Mllama"), ("mluke", "mLUKE"), diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 336d3bf116..b1db245839 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -111,6 +111,7 @@ else: ("mask2former", ("Mask2FormerImageProcessor",)), ("maskformer", ("MaskFormerImageProcessor",)), ("mgp-str", ("ViTImageProcessor", "ViTImageProcessorFast")), + ("mistral3", ("PixtralImageProcessor", "PixtralImageProcessorFast")), ("mllama", ("MllamaImageProcessor",)), ("mobilenet_v1", ("MobileNetV1ImageProcessor",)), ("mobilenet_v2", ("MobileNetV2ImageProcessor",)), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index aa0d120b7f..90dee0eb58 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -361,6 +361,7 @@ MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict( ("mamba2", "Mamba2ForCausalLM"), ("mega", "MegaForMaskedLM"), ("megatron-bert", "MegatronBertForPreTraining"), + ("mistral3", "Mistral3ForConditionalGeneration"), ("mllama", "MllamaForConditionalGeneration"), ("mobilebert", "MobileBertForPreTraining"), ("mpnet", "MPNetForMaskedLM"), @@ -802,6 +803,7 @@ MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict( ("llava_next", "LlavaNextForConditionalGeneration"), ("llava_next_video", "LlavaNextVideoForConditionalGeneration"), ("llava_onevision", "LlavaOnevisionForConditionalGeneration"), + ("mistral3", "Mistral3ForConditionalGeneration"), ("mllama", "MllamaForConditionalGeneration"), ("paligemma", "PaliGemmaForConditionalGeneration"), ("pix2struct", "Pix2StructForConditionalGeneration"), @@ -839,6 +841,7 @@ MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict( ("llava", "LlavaForConditionalGeneration"), ("llava_next", "LlavaNextForConditionalGeneration"), ("llava_onevision", "LlavaOnevisionForConditionalGeneration"), + ("mistral3", "Mistral3ForConditionalGeneration"), ("mllama", "MllamaForConditionalGeneration"), ("paligemma", "PaliGemmaForConditionalGeneration"), ("pix2struct", "Pix2StructForConditionalGeneration"), diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index c65219b0bc..a318d443fb 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -84,6 +84,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict( ("markuplm", "MarkupLMProcessor"), ("mctct", "MCTCTProcessor"), ("mgp-str", "MgpstrProcessor"), + ("mistral3", "PixtralProcessor"), ("mllama", "MllamaProcessor"), ("moonshine", "Wav2Vec2Processor"), ("oneformer", "OneFormerProcessor"), diff --git a/src/transformers/models/mistral3/__init__.py b/src/transformers/models/mistral3/__init__.py new file mode 100644 index 0000000000..11a9fcbdc4 --- /dev/null +++ b/src/transformers/models/mistral3/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_mistral3 import * + from .modeling_mistral3 import * + from .processing_mistral3 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/mistral3/configuration_mistral3.py b/src/transformers/models/mistral3/configuration_mistral3.py new file mode 100644 index 0000000000..e7b27d5722 --- /dev/null +++ b/src/transformers/models/mistral3/configuration_mistral3.py @@ -0,0 +1,137 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...configuration_utils import PretrainedConfig +from ..auto import CONFIG_MAPPING, AutoConfig + + +class Mistral3Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Mistral3ForConditionalGeneration`]. It is used to instantiate an + Mistral3 model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of + [mistralai/Mistral-Small-3.1-24B-Instruct-2503](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `PixtralVisionConfig`): + The config object or dictionary of the vision backbone. + text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `MistralConfig`): + The config object or dictionary of the text backbone. + image_token_index (`int`, *optional*, defaults to 10): + The image token index to encode the image prompt. + projector_hidden_act (`str`, *optional*, defaults to `"gelu"`): + The activation function used by the multimodal projector. + vision_feature_layer (`Union[int, List[int]]`, *optional*, defaults to -1): + The index of the layer to select the vision feature. If multiple indices are provided, + the vision feature of the corresponding indices will be concatenated to form the + vision features. + multimodal_projector_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in the multimodal projector. + spatial_merge_size (`int`, *optional*, defaults to 2): + The downsampling factor for the spatial merge operation. + + Example: + + ```python + >>> from transformers import Mistral3ForConditionalGeneration, Mistral3Config, PixtralVisionConfig, MistralConfig + + >>> # Initializing a Pixtral-vision config + >>> vision_config = PixtralVisionConfig() + + >>> # Initializing a Mistral config + >>> text_config = MistralConfig() + + >>> # Initializing a Mistral3 configuration + >>> configuration = Mistral3Config(vision_config, text_config) + + >>> # Initializing a model from the mistral3.1 configuration + >>> model = Mistral3ForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "mistral3" + sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig} + is_composition = True + + def __init__( + self, + vision_config=None, + text_config=None, + image_token_index=10, + projector_hidden_act="gelu", + vision_feature_layer=-1, + multimodal_projector_bias=False, + spatial_merge_size=2, + **kwargs, + ): + super().__init__(**kwargs) + self.image_token_index = image_token_index + self.projector_hidden_act = projector_hidden_act + + self.vision_feature_layer = vision_feature_layer + + if isinstance(vision_config, dict): + vision_config["model_type"] = vision_config["model_type"] if "model_type" in vision_config else "pixtral" + vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config) + elif vision_config is None: + vision_config = CONFIG_MAPPING["pixtral"]( + intermediate_size=4096, + hidden_size=1024, + patch_size=14, + image_size=1540, + num_hidden_layers=24, + num_attention_heads=16, + vocab_size=32000, + head_dim=64, + hidden_act="gelu", + ) + + self.vision_config = vision_config + + if isinstance(text_config, dict): + text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "mistral" + text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) + elif text_config is None: + text_config = CONFIG_MAPPING["mistral"]( + attention_dropout=0.0, + head_dim=128, + hidden_act="silu", + hidden_size=5120, + initializer_range=0.02, + intermediate_size=32768, + max_position_embeddings=131072, + model_type="mistral", + num_attention_heads=32, + num_hidden_layers=40, + num_key_value_heads=8, + rms_norm_eps=1e-05, + rope_theta=1000000000.0, + sliding_window=None, + use_cache=True, + vocab_size=131072, + ) + + self.text_config = text_config + self.multimodal_projector_bias = multimodal_projector_bias + self.spatial_merge_size = spatial_merge_size + + +__all__ = ["Mistral3Config"] diff --git a/src/transformers/models/mistral3/convert_mistral3_weights_to_hf.py b/src/transformers/models/mistral3/convert_mistral3_weights_to_hf.py new file mode 100644 index 0000000000..11b2d18f04 --- /dev/null +++ b/src/transformers/models/mistral3/convert_mistral3_weights_to_hf.py @@ -0,0 +1,241 @@ +# Copyright 2023 Mistral AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import json +import os +import re + +import torch +from safetensors.torch import load_file + +from transformers import ( + Mistral3Config, + Mistral3ForConditionalGeneration, + MistralConfig, + PixtralImageProcessorFast, + PixtralProcessor, + PixtralVisionConfig, +) +from transformers.integrations.mistral import convert_tekken_tokenizer + + +# fmt: off +STATE_DICT_MAPPING = { + # Text model keys + r"^output.weight": r"language_model.lm_head.weight", + r"^norm.weight": r"language_model.model.norm.weight", + r"^tok_embeddings.weight": r"language_model.model.embed_tokens.weight", + r"^layers.(\d+).attention_norm.weight": r"language_model.model.layers.\1.input_layernorm.weight", + r"^layers.(\d+).ffn_norm.weight": r"language_model.model.layers.\1.post_attention_layernorm.weight", + r"^layers.(\d+).attention.w(q|k|v|o).weight": r"language_model.model.layers.\1.self_attn.\2_proj.weight", + r"^layers.(\d+).feed_forward.w1.weight": r"language_model.model.layers.\1.mlp.gate_proj.weight", + r"^layers.(\d+).feed_forward.w2.weight": r"language_model.model.layers.\1.mlp.down_proj.weight", + r"^layers.(\d+).feed_forward.w3.weight": r"language_model.model.layers.\1.mlp.up_proj.weight", + + # Vision model keys + r"vision_encoder.transformer.layers.(\d+).attention_norm.weight": r"vision_tower.transformer.layers.\1.attention_norm.weight", + r"^vision_encoder.transformer.layers.(\d+).ffn_norm.weight": r"vision_tower.transformer.layers.\1.ffn_norm.weight", + r"^vision_encoder.transformer.layers.(\d+).attention.w(q|k|v|o).weight": r"vision_tower.transformer.layers.\1.attention.\2_proj.weight", + r"^vision_encoder.transformer.layers.(\d+).feed_forward.w1.weight": r"vision_tower.transformer.layers.\1.feed_forward.gate_proj.weight", + r"^vision_encoder.transformer.layers.(\d+).feed_forward.w2.weight": r"vision_tower.transformer.layers.\1.feed_forward.down_proj.weight", + r"^vision_encoder.transformer.layers.(\d+).feed_forward.w3.weight": r"vision_tower.transformer.layers.\1.feed_forward.up_proj.weight", + r"^vision_language_adapter.w_in": r"multi_modal_projector.linear_1", + r"^vision_language_adapter.w_out": r"multi_modal_projector.linear_2", + r"^vision_encoder.ln_pre.weight": r"vision_tower.ln_pre.weight", + r"^vision_encoder.patch_conv.weight": r"vision_tower.patch_conv.weight", + r"^patch_merger.merging_layer.weight": r"multi_modal_projector.patch_merger.merging_layer.weight", + r"^pre_mm_projector_norm.weight": r"multi_modal_projector.norm.weight", +} +# fmt: on + + +def map_old_key_to_new(old_key): + """Map of a key of the original state dict to the equivalent key in HF format""" + for pattern, replacement in STATE_DICT_MAPPING.items(): + new_key, n_replace = re.subn(pattern, replacement, old_key) + # Early exit of the loop + if n_replace > 0: + return new_key + + raise ValueError(f"Key: {old_key} could not be mapped (check the mapping).") + + +def read_json(path): + with open(path, "r") as f: + return json.load(f) + + +def permute_for_rope(tensor, n_heads, dim1, dim2): + """Permute the weights for the ROPE formulation.""" + tensor = tensor.view(n_heads, dim1 // n_heads // 2, 2, dim2) + tensor = tensor.transpose(1, 2) + tensor = tensor.reshape(dim1, dim2) + return tensor + + +def convert_state_dict(original_state_dict: dict, config: MistralConfig): + """Convert a state dict file, when a single `nn.Module` is never sharded in different files (usual case).""" + new_dict = {} + + for old_key, tensor in original_state_dict.items(): + new_key = map_old_key_to_new(old_key) + + if "vision" in old_key: + num_attention_heads = config.vision_config.num_attention_heads + num_key_value_heads = num_attention_heads + hidden_size = config.vision_config.hidden_size + head_dim = config.vision_config.head_dim + key_value_dim = head_dim * num_attention_heads + query_dim = head_dim * num_attention_heads + else: + num_attention_heads = config.text_config.num_attention_heads + hidden_size = config.text_config.hidden_size + head_dim = config.text_config.head_dim + num_key_value_heads = config.text_config.num_key_value_heads + key_value_dim = head_dim * num_key_value_heads + query_dim = head_dim * num_attention_heads + + if "q_proj" in new_key: + tensor = permute_for_rope(tensor, num_attention_heads, query_dim, hidden_size) + elif "k_proj" in new_key: + tensor = permute_for_rope(tensor, num_key_value_heads, key_value_dim, hidden_size) + + new_dict[new_key] = tensor + return new_dict + + +def convert_config(original_config: dict, max_position_embeddings: int = 131072): + original_vision_config = original_config.pop("vision_encoder") + original_text_config = original_config + + # Text config + text_key_mapping = { + "hidden_size": "dim", + "num_hidden_layers": "n_layers", + "intermediate_size": "hidden_dim", + "num_attention_heads": "n_heads", + "num_key_value_heads": "n_kv_heads", + "rms_norm_eps": "norm_eps", + } + similar_text_keys_to_keep = [ + "head_dim", + "vocab_size", + "rope_theta", + ] + new_text_config_kwargs = {k: original_text_config[v] for k, v in text_key_mapping.items()} + new_text_config_kwargs.update({k: v for k, v in original_text_config.items() if k in similar_text_keys_to_keep}) + # These are not always defined depending on `params.json` + new_text_config_kwargs["sliding_window"] = original_text_config.get("sliding_window", None) + new_text_config_kwargs["max_position_embeddings"] = original_text_config.get( + "max_seq_len", max_position_embeddings + ) + # This may sometimes be a string in `params.json` + if new_text_config_kwargs["sliding_window"] is not None: + new_text_config_kwargs["sliding_window"] = int(new_text_config_kwargs["sliding_window"]) + new_text_config = MistralConfig(**new_text_config_kwargs) + + # Vision config + new_vision_config = original_vision_config + adapter_bias = new_vision_config.pop("adapter_bias", False) + _ = new_vision_config.pop("mm_projector_id", None) + _ = new_vision_config.pop("add_pre_mm_projector_layer_norm", None) + spatial_merge_size = new_vision_config.pop("spatial_merge_size") + image_token_id = new_vision_config.pop("image_token_id", 10) + _ = new_vision_config.pop("image_break_token_id", 12) + _ = new_vision_config.pop("image_end_token_id", 13) + _ = new_vision_config.pop("max_image_size") + new_vision_config = PixtralVisionConfig(**new_vision_config) + + new_config = Mistral3Config( + vision_config=new_vision_config, + text_config=new_text_config, + multimodal_projector_bias=adapter_bias, + image_token_index=image_token_id, + spatial_merge_size=spatial_merge_size, + vision_feature_layer=-1, + ) + return new_config + + +def convert_and_write_model(input_dir: str, output_dir: str, max_position_embeddings: int): + """Convert the model and save it (this implicitly save the config as well).""" + params = read_json(os.path.join(input_dir, "params.json")) + config = convert_config(params, max_position_embeddings) + + full_state_dict = {} + # The model may be split between different files, but a single nn.Module is always fully present in a single file + shards = [file for file in os.listdir(input_dir) if file.endswith(".safetensors")] + for shard_file in shards: + original_state_dict = load_file(os.path.join(input_dir, shard_file)) + new_dict = convert_state_dict(original_state_dict, config) + full_state_dict.update(new_dict) + + # Load weights into model and resave them + with torch.device("meta"): + model = Mistral3ForConditionalGeneration(config) + model.load_state_dict(full_state_dict, strict=True, assign=True) + model.save_pretrained(output_dir) + + +def convert_and_write_processor(input_dir: str, output_dir: str): + """Convert the tokenizer and save it.""" + tokenizer_file = os.path.join(input_dir, "tekken.json") + tokenizer = convert_tekken_tokenizer(tokenizer_file) + tokenizer.add_special_tokens({"pad_token": ""}) + chat_template = '{%- if messages[0]["role"] == "system" %}{%- set system_message = messages[0]["content"] %}{%- set loop_messages = messages[1:] %}\n{%- else %}{%- set loop_messages = messages %}{%- endif %}{{- bos_token }}{%- for message in loop_messages %}{%- if (message[\'role\'] == \'user\') != (loop.index0 % 2 == 0) %}{{- raise_exception(\'After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\') }}{%- endif %}{%- if message["role"] == "user" %}{%- if loop.last and system_message is defined %}{{- "[INST]" + system_message + "\n\n" }}{%- else %}{{ "[INST]" }}{%- endif %}{%- endif %}{%- if message["content"] is not string %}{%- for chunk in message["content"] %}{%- if chunk["type"] == "text" %}{%- if "content" in chunk %}{{- chunk["content"] }}{%- elif "text" in chunk %}{{- chunk["text"] }}{%- endif %}{%- elif chunk["type"] == "image" %}{{- "[IMG]" }}{%- else %}{{- raise_exception("Unrecognized content type!") }}{%- endif %}{%- endfor %}{%- else %}{{- message["content"] }}{%- endif %}{%- if message["role"] == "user" %}{{- "[/INST]" }}{%- elif message["role"] == "assistant" %}{{- eos_token}}{%- else %}{{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}{%- endif %}{%- endfor %}' + + config = read_json(os.path.join(input_dir, "params.json")) + patch_size = config["vision_encoder"]["patch_size"] + spatial_merge_size = config["vision_encoder"]["spatial_merge_size"] + max_image_size = config["vision_encoder"]["max_image_size"] + image_processor = PixtralImageProcessorFast(patch_size=patch_size, size={"longest_edge": max_image_size}) + + processor = PixtralProcessor( + tokenizer=tokenizer, + image_processor=image_processor, + image_token="[IMG]", + patch_size=patch_size, + chat_template=chat_template, + spatial_merge_size=spatial_merge_size, + ) + + # Finally save it + processor.save_pretrained(output_dir) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "input_dir", + help="Location of Mistral weights, which contains tokenizer.model and model folders", + ) + parser.add_argument( + "output_dir", + help="Location to write HF model and tokenizer", + ) + parser.add_argument( + "--max_position_embeddings", + type=int, + default=131072, + help="`max_position_embeddings` field in the config. This needs to be manually passed (not present anywhere otherwise).", + ) + + args = parser.parse_args() + + convert_and_write_model(args.input_dir, args.output_dir, args.max_position_embeddings) + convert_and_write_processor(args.input_dir, args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/src/transformers/models/mistral3/modeling_mistral3.py b/src/transformers/models/mistral3/modeling_mistral3.py new file mode 100644 index 0000000000..4ded5efed6 --- /dev/null +++ b/src/transformers/models/mistral3/modeling_mistral3.py @@ -0,0 +1,553 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/mistral3/modular_mistral3.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_mistral3.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...generation import GenerationMixin +from ...modeling_outputs import ModelOutput +from ...modeling_utils import PreTrainedModel +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_torchdynamo_compiling, + replace_return_docstrings, +) +from ...utils.deprecation import deprecate_kwarg +from ..auto import AutoModel, AutoModelForCausalLM +from .configuration_mistral3 import Mistral3Config + + +_CONFIG_FOR_DOC = "Mistral3Config" + + +class Mistral3RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Mistral3RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Mistral3PatchMerger(nn.Module): + """ + Learned merging of spatial_merge_size ** 2 patches + """ + + def __init__(self, config: Mistral3Config): + super().__init__() + self.config = config + + hidden_size = config.vision_config.hidden_size + self.spatial_merge_size = config.spatial_merge_size + self.patch_size = self.config.vision_config.patch_size + self.merging_layer = nn.Linear(hidden_size * self.spatial_merge_size**2, hidden_size, bias=False) + + def forward(self, image_features: torch.Tensor, image_sizes: torch.Tensor) -> torch.Tensor: + image_sizes = [ + (image_size[0] // self.patch_size, image_size[1] // self.patch_size) for image_size in image_sizes + ] + + tokens_per_image = [h * w for h, w in image_sizes] + d = image_features.shape[-1] + + permuted_tensor = [] + for image_index, image_tokens in enumerate(image_features.split(tokens_per_image)): + # Reshape image_tokens into a 2D grid + h, w = image_sizes[image_index] + image_grid = image_tokens.view(h, w, d).permute(2, 0, 1).unsqueeze(0) + grid = torch.nn.functional.unfold( + image_grid, kernel_size=self.spatial_merge_size, stride=self.spatial_merge_size + ) + grid = grid.view(d * self.spatial_merge_size**2, -1).t() + permuted_tensor.append(grid) + + image_features = torch.cat(permuted_tensor, dim=0) + image_features = self.merging_layer(image_features) + return image_features + + +class Mistral3MultiModalProjector(nn.Module): + def __init__(self, config: Mistral3Config): + super().__init__() + self.norm = Mistral3RMSNorm(config.vision_config.hidden_size) + self.patch_merger = Mistral3PatchMerger(config) + # We have hidden_size * the number of vision feature layers + num_feature_layers = 1 if isinstance(config.vision_feature_layer, int) else len(config.vision_feature_layer) + self.linear_1 = nn.Linear( + config.vision_config.hidden_size * num_feature_layers, + config.text_config.hidden_size, + bias=config.multimodal_projector_bias, + ) + self.act = ACT2FN[config.projector_hidden_act] + self.linear_2 = nn.Linear( + config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias + ) + + def forward(self, image_features: torch.Tensor, image_sizes: torch.Tensor): + image_features = self.norm(image_features) + image_features = self.patch_merger(image_features, image_sizes) + hidden_states = self.linear_1(image_features) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +@dataclass +class Mistral3CausalLMOutputWithPast(ModelOutput): + """ + Base class for Mistral3 causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size (batch_size, num_images, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[List[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + image_hidden_states: Optional[torch.FloatTensor] = None + + +MISTRAL3_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Mistral3Config`] or [`Mistral3VisionConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + MISTRAL3_START_DOCSTRING, +) +class Mistral3PreTrainedModel(PreTrainedModel): + config_class = Mistral3Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Mistral3VisionAttention"] + _skip_keys_device_placement = "past_key_values" + _supports_cache_class = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + # important: this ported version of Mistral3 isn't meant for training from scratch - only + # inference and fine-tuning - so the proper init weights code has been removed - the original codebase + # https://github.com/haotian-liu/Mistral3/tree/main/mistral3 should serve for that purpose + std = ( + self.config.initializer_range + if hasattr(self.config, "initializer_range") + else self.config.text_config.initializer_range + ) + + if hasattr(module, "class_embedding"): + module.class_embedding.data.normal_(mean=0.0, std=std) + + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +MISTRAL3_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)): + The tensors corresponding to the input images. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details ([]`Mistral3Processor`] uses + [`CLIPImageProcessor`] for processing images). + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + vision_feature_layer (`Union[int, List[int]], *optional*, defaults to -2`): + The index of the layer to select the vision feature. If multiple indices are provided, + the vision feature of the corresponding indices will be concatenated to form the + vision features. + vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`): + The feature selection strategy used to select the vision feature from the vision backbone. + Can be one of `"default"` or `"full"`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + """The MISTRAL3 model which consists of a vision backbone and a language model.""", + MISTRAL3_START_DOCSTRING, +) +class Mistral3ForConditionalGeneration(Mistral3PreTrainedModel, GenerationMixin): + def __init__(self, config: Mistral3Config): + super().__init__(config) + self.vision_tower = AutoModel.from_config(config.vision_config) + + self.multi_modal_projector = Mistral3MultiModalProjector(config) + self.vocab_size = config.text_config.vocab_size + self.language_model = AutoModelForCausalLM.from_config(config.text_config) + + if self.language_model._tied_weights_keys is not None: + self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys] + + self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 + + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def get_output_embeddings(self): + return self.language_model.get_output_embeddings() + + def set_output_embeddings(self, new_embeddings): + self.language_model.set_output_embeddings(new_embeddings) + + def set_decoder(self, decoder): + self.language_model.set_decoder(decoder) + + def get_decoder(self): + return self.language_model.get_decoder() + + def get_image_features( + self, + pixel_values: torch.FloatTensor, + vision_feature_layer: Union[int, List[int]], + image_sizes: torch.Tensor, + **kwargs, + ): + """ + Obtains image last hidden states from the vision tower and apply multimodal projection. + + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`): + The tensors corresponding to the input images. + vision_feature_layer (`Union[int, List[int]]`): + The index of the layer to select the vision feature. If multiple indices are provided, + the vision feature of the corresponding indices will be concatenated to form the + vision features. + image_sizes (`torch.Tensor`): + Tensor containing the image sizes as returned by the processor. + Returns: + image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). + """ + kwargs = {k: v for k, v in kwargs.items() if v is not None} + # this is not memory efficient at all (output_hidden_states=True) will save all the hidden states. + image_outputs = self.vision_tower(pixel_values, image_sizes=image_sizes, output_hidden_states=True, **kwargs) + # If we have one vision feature layer, return the corresponding hidden states, + # otherwise, select the hidden states of each feature layer and concatenate them + if isinstance(vision_feature_layer, int): + selected_image_feature = image_outputs.hidden_states[vision_feature_layer] + else: + hs_pool = [image_outputs.hidden_states[layer_idx] for layer_idx in vision_feature_layer] + selected_image_feature = torch.cat(hs_pool, dim=-1) + + image_features = self.multi_modal_projector(selected_image_feature.squeeze(0), image_sizes) + return image_features + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + @add_start_docstrings_to_model_forward(MISTRAL3_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Mistral3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + vision_feature_layer: Optional[Union[int, List[int]]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + image_sizes: torch.Tensor = None, + **lm_kwargs, + ) -> Union[Tuple, Mistral3CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + + Returns: + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Mistral3ForConditionalGeneration + + >>> model = Mistral3ForConditionalGeneration.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503") + >>> processor = AutoProcessor.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503") + + >>> prompt = "[INST][IMG]What is the image?[/INST]" + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, text=prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(**inputs, max_new_tokens=15) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "What is the image?The image depicts two cats lying on a pink blanket." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if pixel_values is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" + ) + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if pixel_values is not None: + image_features = self.get_image_features( + pixel_values=pixel_values, + vision_feature_layer=vision_feature_layer, + image_sizes=image_sizes, + ) + + special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): + n_image_tokens = (input_ids == self.config.image_token_index).sum() + n_image_features = image_features.shape[0] * image_features.shape[1] + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + logits_to_keep=logits_to_keep, + **lm_kwargs, + ) + + logits = outputs[0] + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + if attention_mask is not None: + # we use the input attention mask to shift the logits and labels, because it is 2D. + # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft + shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device) + shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous() + shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous() + else: + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct( + shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device) + ) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return Mistral3CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + inputs_embeds=None, + pixel_values=None, + attention_mask=None, + cache_position=None, + logits_to_keep=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + model_inputs = self.language_model.prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + logits_to_keep=logits_to_keep, + **kwargs, + ) + + if cache_position[0] == 0: + # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore + # Otherwise we need pixel values to be passed to model + model_inputs["pixel_values"] = pixel_values + + return model_inputs + + +__all__ = ["Mistral3PreTrainedModel", "Mistral3ForConditionalGeneration"] diff --git a/src/transformers/models/mistral3/modular_mistral3.py b/src/transformers/models/mistral3/modular_mistral3.py new file mode 100644 index 0000000000..9d1edf97bd --- /dev/null +++ b/src/transformers/models/mistral3/modular_mistral3.py @@ -0,0 +1,286 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Optional, Tuple, Union + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...utils import is_torchdynamo_compiling, logging +from ..llava.modeling_llava import LlavaCausalLMOutputWithPast, LlavaForConditionalGeneration +from ..mistral.modeling_mistral import MistralRMSNorm +from .configuration_mistral3 import Mistral3Config + + +logger = logging.get_logger(__name__) + + +class Mistral3RMSNorm(MistralRMSNorm): + pass + + +class Mistral3PatchMerger(nn.Module): + """ + Learned merging of spatial_merge_size ** 2 patches + """ + + def __init__(self, config: Mistral3Config): + super().__init__() + self.config = config + + hidden_size = config.vision_config.hidden_size + self.spatial_merge_size = config.spatial_merge_size + self.patch_size = self.config.vision_config.patch_size + self.merging_layer = nn.Linear(hidden_size * self.spatial_merge_size**2, hidden_size, bias=False) + + def forward(self, image_features: torch.Tensor, image_sizes: torch.Tensor) -> torch.Tensor: + image_sizes = [ + (image_size[0] // self.patch_size, image_size[1] // self.patch_size) for image_size in image_sizes + ] + + tokens_per_image = [h * w for h, w in image_sizes] + d = image_features.shape[-1] + + permuted_tensor = [] + for image_index, image_tokens in enumerate(image_features.split(tokens_per_image)): + # Reshape image_tokens into a 2D grid + h, w = image_sizes[image_index] + image_grid = image_tokens.view(h, w, d).permute(2, 0, 1).unsqueeze(0) + grid = torch.nn.functional.unfold( + image_grid, kernel_size=self.spatial_merge_size, stride=self.spatial_merge_size + ) + grid = grid.view(d * self.spatial_merge_size**2, -1).t() + permuted_tensor.append(grid) + + image_features = torch.cat(permuted_tensor, dim=0) + image_features = self.merging_layer(image_features) + return image_features + + +class Mistral3MultiModalProjector(nn.Module): + def __init__(self, config: Mistral3Config): + super().__init__() + self.norm = Mistral3RMSNorm(config.vision_config.hidden_size) + self.patch_merger = Mistral3PatchMerger(config) + # We have hidden_size * the number of vision feature layers + num_feature_layers = 1 if isinstance(config.vision_feature_layer, int) else len(config.vision_feature_layer) + self.linear_1 = nn.Linear( + config.vision_config.hidden_size * num_feature_layers, + config.text_config.hidden_size, + bias=config.multimodal_projector_bias, + ) + self.act = ACT2FN[config.projector_hidden_act] + self.linear_2 = nn.Linear( + config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias + ) + + def forward(self, image_features: torch.Tensor, image_sizes: torch.Tensor): + image_features = self.norm(image_features) + image_features = self.patch_merger(image_features, image_sizes) + hidden_states = self.linear_1(image_features) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +class Mistral3CausalLMOutputWithPast(LlavaCausalLMOutputWithPast): + pass + + +class Mistral3ForConditionalGeneration(LlavaForConditionalGeneration): + def get_image_features( + self, + pixel_values: torch.FloatTensor, + vision_feature_layer: Union[int, List[int]], + image_sizes: torch.Tensor, + **kwargs, + ): + """ + Obtains image last hidden states from the vision tower and apply multimodal projection. + + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`): + The tensors corresponding to the input images. + vision_feature_layer (`Union[int, List[int]]`): + The index of the layer to select the vision feature. If multiple indices are provided, + the vision feature of the corresponding indices will be concatenated to form the + vision features. + image_sizes (`torch.Tensor`): + Tensor containing the image sizes as returned by the processor. + Returns: + image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). + """ + kwargs = {k: v for k, v in kwargs.items() if v is not None} + # this is not memory efficient at all (output_hidden_states=True) will save all the hidden states. + image_outputs = self.vision_tower(pixel_values, image_sizes=image_sizes, output_hidden_states=True, **kwargs) + # If we have one vision feature layer, return the corresponding hidden states, + # otherwise, select the hidden states of each feature layer and concatenate them + if isinstance(vision_feature_layer, int): + selected_image_feature = image_outputs.hidden_states[vision_feature_layer] + else: + hs_pool = [image_outputs.hidden_states[layer_idx] for layer_idx in vision_feature_layer] + selected_image_feature = torch.cat(hs_pool, dim=-1) + + image_features = self.multi_modal_projector(selected_image_feature.squeeze(0), image_sizes) + return image_features + + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + vision_feature_layer: Optional[Union[int, List[int]]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + image_sizes: torch.Tensor = None, + **lm_kwargs, + ) -> Union[Tuple, Mistral3CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + + Returns: + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Mistral3ForConditionalGeneration + + >>> model = Mistral3ForConditionalGeneration.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503") + >>> processor = AutoProcessor.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503") + + >>> prompt = "[INST][IMG]What is the image?[/INST]" + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, text=prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(**inputs, max_new_tokens=15) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "What is the image?The image depicts two cats lying on a pink blanket." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if pixel_values is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" + ) + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if pixel_values is not None: + image_features = self.get_image_features( + pixel_values=pixel_values, + vision_feature_layer=vision_feature_layer, + image_sizes=image_sizes, + ) + + special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): + n_image_tokens = (input_ids == self.config.image_token_index).sum() + n_image_features = image_features.shape[0] * image_features.shape[1] + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + logits_to_keep=logits_to_keep, + **lm_kwargs, + ) + + logits = outputs[0] + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + if attention_mask is not None: + # we use the input attention mask to shift the logits and labels, because it is 2D. + # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft + shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device) + shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous() + shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous() + else: + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct( + shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device) + ) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return Mistral3CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + ) + + +__all__ = [ + "Mistral3PreTrainedModel", # noqa + "Mistral3ForConditionalGeneration", +] diff --git a/src/transformers/models/pixtral/image_processing_pixtral.py b/src/transformers/models/pixtral/image_processing_pixtral.py index 969575d2e4..2cb452863a 100644 --- a/src/transformers/models/pixtral/image_processing_pixtral.py +++ b/src/transformers/models/pixtral/image_processing_pixtral.py @@ -128,8 +128,9 @@ def get_resize_output_image_size( if ratio > 1: # Orgiginal implementation uses `round` which utilises bankers rounding, which can lead to surprising results - height = int(math.ceil(height / ratio)) - width = int(math.ceil(width / ratio)) + # Here we use floor to ensure the image is always smaller than the given "longest_edge" + height = int(math.floor(height / ratio)) + width = int(math.floor(width / ratio)) num_height_tokens, num_width_tokens = _num_image_tokens((height, width), (patch_height, patch_width)) return num_height_tokens * patch_height, num_width_tokens * patch_width diff --git a/src/transformers/models/pixtral/processing_pixtral.py b/src/transformers/models/pixtral/processing_pixtral.py index d6130699fd..66da1bf9f7 100644 --- a/src/transformers/models/pixtral/processing_pixtral.py +++ b/src/transformers/models/pixtral/processing_pixtral.py @@ -64,6 +64,8 @@ class PixtralProcessor(ProcessorMixin): The tokenizer is a required input. patch_size (`int`, *optional*, defaults to 16): Patch size from the vision tower. + spatial_merge_size (`int`, *optional*, defaults to 1): + The downsampling factor for the spatial merge operation. chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages in a chat into a tokenizable string. image_token (`str`, *optional*, defaults to `"[IMG]"`): @@ -78,6 +80,7 @@ class PixtralProcessor(ProcessorMixin): valid_kwargs = [ "chat_template", "patch_size", + "spatial_merge_size", "image_token", "image_break_token", "image_end_token", @@ -90,6 +93,7 @@ class PixtralProcessor(ProcessorMixin): image_processor=None, tokenizer=None, patch_size: int = 16, + spatial_merge_size: int = 1, chat_template=None, image_token="[IMG]", # set the default and let users change if they have peculiar special tokens in rare cases image_break_token="[IMG_BREAK]", @@ -97,6 +101,7 @@ class PixtralProcessor(ProcessorMixin): **kwargs, ): self.patch_size = patch_size + self.spatial_merge_size = spatial_merge_size self.image_token = image_token self.image_break_token = image_break_token self.image_end_token = image_end_token @@ -187,8 +192,8 @@ class PixtralProcessor(ProcessorMixin): for sample in text: while self.image_token in sample: height, width = next(image_sizes) - num_height_tokens = height // self.patch_size - num_width_tokens = width // self.patch_size + num_height_tokens = height // (self.patch_size * self.spatial_merge_size) + num_width_tokens = width // (self.patch_size * self.spatial_merge_size) replace_tokens = [ [self.image_token] * num_width_tokens + [self.image_break_token] ] * num_height_tokens diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 76bcf6fc1c..c4532ead93 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -6392,6 +6392,20 @@ class MistralPreTrainedModel(metaclass=DummyObject): requires_backends(self, ["torch"]) +class Mistral3ForConditionalGeneration(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Mistral3PreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class MixtralForCausalLM(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index daa9e2f70c..e8e14b0497 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -125,6 +125,7 @@ VLM_CLASS_NAMES = [ "qwen2_5_vl", "ayavision", "gemma3", + "mistral3", ] diff --git a/tests/models/mistral3/__init__.py b/tests/models/mistral3/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/models/mistral3/test_modeling_mistral3.py b/tests/models/mistral3/test_modeling_mistral3.py new file mode 100644 index 0000000000..d6f225b561 --- /dev/null +++ b/tests/models/mistral3/test_modeling_mistral3.py @@ -0,0 +1,482 @@ +# coding=utf-8 +# Copyright 2024 The Qwen team, Alibaba Group and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch GotOcr2 model.""" + +import unittest + +from transformers import ( + AutoProcessor, + Mistral3Config, + is_bitsandbytes_available, + is_torch_available, +) +from transformers.testing_utils import ( + cleanup, + require_bitsandbytes, + require_torch, + require_torch_gpu, + slow, + torch_device, +) + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_torch_available(): + import torch + + from transformers import ( + Mistral3ForConditionalGeneration, + ) + + +if is_bitsandbytes_available(): + from transformers import BitsAndBytesConfig + + +class Mistral3VisionText2TextModelTester: + def __init__( + self, + parent, + batch_size=3, + seq_length=7, + image_seq_length=4, + vision_feature_layer=-1, + ignore_index=-100, + bos_token_id=0, + eos_token_id=0, + pad_token_id=0, + image_token_index=1, + num_channels=3, + image_size=30, + model_type="mistral3", + is_training=True, + text_config={ + "model_type": "mistral", + "vocab_size": 99, + "attention_dropout": 0.0, + "hidden_act": "silu", + "hidden_size": 32, + "initializer_range": 0.02, + "intermediate_size": 37, + "max_position_embeddings": 512, + "num_attention_heads": 4, + "num_hidden_layers": 2, + "num_key_value_heads": 2, + "rms_norm_eps": 1e-05, + "rope_theta": 1000000000.0, + "sliding_window": None, + "bos_token_id": 0, + "eos_token_id": 0, + "pad_token_id": 0, + }, + vision_config={ + "model_type": "pixtral", + "hidden_size": 32, + "num_hidden_layers": 2, + "num_attention_heads": 4, + "intermediate_size": 37, + "image_size": 30, + "patch_size": 6, + "num_channels": 3, + "hidden_act": "gelu", + }, + ): + self.parent = parent + self.ignore_index = ignore_index + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.image_token_index = image_token_index + self.model_type = model_type + self.text_config = text_config + self.vision_config = vision_config + self.batch_size = batch_size + self.vision_feature_layer = vision_feature_layer + self.is_training = is_training + self.image_seq_length = image_seq_length + self.num_channels = num_channels + self.image_size = image_size + self.seq_length = seq_length + self.image_seq_length + + self.num_hidden_layers = text_config["num_hidden_layers"] + self.vocab_size = text_config["vocab_size"] + self.hidden_size = text_config["hidden_size"] + self.num_attention_heads = text_config["num_attention_heads"] + + def get_config(self): + return Mistral3Config( + text_config=self.text_config, + vision_config=self.vision_config, + model_type=self.model_type, + bos_token_id=self.bos_token_id, + eos_token_id=self.eos_token_id, + pad_token_id=self.pad_token_id, + image_token_index=self.image_token_index, + image_seq_length=self.image_seq_length, + vision_feature_layer=self.vision_feature_layer, + ) + + def prepare_config_and_inputs(self): + config = self.get_config() + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + + return config, pixel_values + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values = config_and_inputs + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device) + image_sizes = torch.tensor( + [[self.image_size, self.image_size]] * self.batch_size, dtype=torch.long, device=torch_device + ) + + # input_ids[:, -1] = self.pad_token_id + input_ids[input_ids == self.image_token_index] = self.pad_token_id + input_ids[:, : self.image_seq_length] = self.image_token_index + + inputs_dict = { + "pixel_values": pixel_values, + "input_ids": input_ids, + "attention_mask": attention_mask, + "image_sizes": image_sizes, + } + return config, inputs_dict + + def create_and_check_model_fp16_forward(self, config, input_ids, pixel_values, attention_mask): + model = Mistral3ForConditionalGeneration(config=config) + model.to(torch_device) + model.half() + model.eval() + logits = model( + input_ids=input_ids, + attention_mask=attention_mask, + pixel_values=pixel_values.to(torch.bfloat16), + return_dict=True, + )["logits"] + self.parent.assertFalse(torch.isnan(logits).any().item()) + + def create_and_check_model_fp16_autocast_forward(self, config, input_ids, pixel_values, attention_mask): + config.torch_dtype = torch.float16 + model = Mistral3ForConditionalGeneration(config=config) + model.to(torch_device) + model.eval() + with torch.autocast(device_type="cuda", dtype=torch.float16): + logits = model( + input_ids=input_ids, + attention_mask=attention_mask, + pixel_values=pixel_values.to(torch.bfloat16), + return_dict=True, + )["logits"] + self.parent.assertFalse(torch.isnan(logits).any().item()) + + +@require_torch +class Mistral3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = (Mistral3ForConditionalGeneration,) if is_torch_available() else () + all_generative_model_classes = (Mistral3ForConditionalGeneration,) if is_torch_available() else () + pipeline_model_mapping = ( + { + "image-text-to-text": Mistral3ForConditionalGeneration, + } + if is_torch_available() + else {} + ) + _is_composite = True + test_headmasking = False + test_pruning = False + + def setUp(self): + self.model_tester = Mistral3VisionText2TextModelTester(self) + self.config_tester = ConfigTester(self, config_class=Mistral3Config, has_text_modality=False) + + def test_config(self): + # overwritten from `tests/test_configuration_common.py::ConfigTester` after #36077 + # TODO: avoid overwritten once there is a better fix for #36077 + def check_config_can_be_init_without_params(): + config = self.config_tester.config_class() + self.config_tester.parent.assertIsNotNone(config) + + self.config_tester.check_config_can_be_init_without_params = check_config_can_be_init_without_params + self.config_tester.run_common_tests() + + def test_initialization(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + configs_no_init = _config_zero_init(config) + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + for name, param in model.named_parameters(): + if param.requires_grad: + self.assertIn( + ((param.data.mean() * 1e9).round() / 1e9).item(), + [0.0, 1.0], + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + + # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs + def test_inputs_embeds(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + + inputs = self._prepare_for_class(inputs_dict, model_class) + + input_ids = inputs["input_ids"] + del inputs["input_ids"] + del inputs["pixel_values"] + + wte = model.get_input_embeddings() + inputs["inputs_embeds"] = wte(input_ids) + + with torch.no_grad(): + model(**inputs) + + # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs + # while some other models require pixel_values to be present + def test_inputs_embeds_matches_input_ids(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + + inputs = self._prepare_for_class(inputs_dict, model_class) + input_ids = inputs["input_ids"] + del inputs["input_ids"] + del inputs["pixel_values"] + + inputs_embeds = model.get_input_embeddings()(input_ids) + + with torch.no_grad(): + out_ids = model(input_ids=input_ids, **inputs)[0] + out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0] + torch.testing.assert_close(out_embeds, out_ids) + + @unittest.skip(reason="Compile not yet supported because in LLava models") + def test_sdpa_can_compile_dynamic(self): + pass + + @unittest.skip("FlashAttention only support fp16 and bf16 data type") + def test_flash_attn_2_fp32_ln(self): + pass + + @unittest.skip("Pixtral does not support attention interfaces.") + def test_eager_matches_fa2_generate(self): + pass + + @unittest.skip("Pixtral does not support attention interfaces.") + def test_eager_matches_sdpa_generate(self): + pass + + @unittest.skip("Pixtral does not support attention interfaces.") + def test_flash_attn_2_from_config(self): + pass + + @unittest.skip("Pixtral does not support attention interfaces.") + def test_flash_attn_2_inference_equivalence(self): + pass + + @unittest.skip("Pixtral does not support attention interfaces.") + def test_flash_attn_2_inference_equivalence_right_padding(self): + pass + + @unittest.skip("Pixtral does not support attention interfaces.") + def test_sdpa_can_dispatch_on_flash(self): + pass + + +@slow +@require_torch_gpu +class Mistral3IntegrationTest(unittest.TestCase): + def setUp(self): + self.model_checkpoint = "mistralai/Mistral-Small-3.1-24B-Instruct-2503" + + def tearDown(self): + cleanup(torch_device, gc_collect=True) + + def test_mistral3_integration_generate_text_only(self): + processor = AutoProcessor.from_pretrained(self.model_checkpoint) + model = Mistral3ForConditionalGeneration.from_pretrained( + self.model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16 + ) + + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Write a haiku"}, + ], + } + ] + + inputs = processor.apply_chat_template( + messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" + ).to(torch_device, dtype=torch.bfloat16) + + with torch.no_grad(): + generate_ids = model.generate(**inputs, max_new_tokens=200, do_sample=False) + decoded_output = processor.decode( + generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True + ) + expected_output = "Sure, here's a haiku for you:\n\nWhispers of the breeze,\nCherry blossoms softly fall,\nSpring's gentle embrace." + self.assertEqual(decoded_output, expected_output) + + def test_mistral3_integration_generate(self): + processor = AutoProcessor.from_pretrained(self.model_checkpoint) + model = Mistral3ForConditionalGeneration.from_pretrained( + self.model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16 + ) + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "url": "http://images.cocodataset.org/val2017/000000039769.jpg"}, + {"type": "text", "text": "Describe this image"}, + ], + } + ] + + inputs = processor.apply_chat_template( + messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" + ).to(torch_device, dtype=torch.bfloat16) + with torch.no_grad(): + generate_ids = model.generate(**inputs, max_new_tokens=20, do_sample=False) + decoded_output = processor.decode( + generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True + ) + expected_output = "The image depicts two cats lying on a pink blanket. The larger cat, which appears to be an" + self.assertEqual(decoded_output, expected_output) + + def test_mistral3_integration_batched_generate(self): + processor = AutoProcessor.from_pretrained(self.model_checkpoint) + model = Mistral3ForConditionalGeneration.from_pretrained( + self.model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16 + ) + messages = [ + [ + { + "role": "user", + "content": [ + {"type": "image", "url": "https://llava-vl.github.io/static/images/view.jpg"}, + {"type": "text", "text": "Write a haiku for this image"}, + ], + }, + ], + [ + { + "role": "user", + "content": [ + {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}, + {"type": "text", "text": "Describe this image"}, + ], + }, + ], + ] + + inputs = processor.apply_chat_template( + messages, padding=True, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" + ).to(model.device, dtype=torch.bfloat16) + + output = model.generate(**inputs, do_sample=False, max_new_tokens=25) + + # Check first output + decoded_output = processor.decode(output[0], skip_special_tokens=True) + expected_output = "Write a haiku for this imageSure, here is a haiku inspired by the image:\n\nCalm lake's mirror gleams,\nWhispering pines" + self.assertEqual( + decoded_output, + expected_output, + f"Decoded output: {decoded_output}\nExpected output: {expected_output}", + ) + + # Check second output + decoded_output = processor.decode(output[1], skip_special_tokens=True) + expected_output = "Describe this imageThe image depicts a vibrant street scene in what appears to be a Chinatown district. The focal point is a traditional Chinese" + self.assertEqual( + decoded_output, + expected_output, + f"Decoded output: {decoded_output}\nExpected output: {expected_output}", + ) + + @require_bitsandbytes + def test_mistral3_integration_batched_generate_multi_image(self): + processor = AutoProcessor.from_pretrained(self.model_checkpoint) + quantization_config = BitsAndBytesConfig(load_in_4bit=True) + model = Mistral3ForConditionalGeneration.from_pretrained( + self.model_checkpoint, quantization_config=quantization_config + ) + + # Prepare inputs + messages = [ + [ + { + "role": "user", + "content": [ + {"type": "image", "url": "https://llava-vl.github.io/static/images/view.jpg"}, + {"type": "text", "text": "Write a haiku for this image"}, + ], + }, + ], + [ + { + "role": "user", + "content": [ + { + "type": "image", + "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg", + }, + { + "type": "image", + "url": "https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg", + }, + { + "type": "text", + "text": "These images depict two different landmarks. Can you identify them?", + }, + ], + }, + ], + ] + inputs = processor.apply_chat_template( + messages, padding=True, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" + ).to(model.device, dtype=torch.float16) + + output = model.generate(**inputs, do_sample=False, max_new_tokens=25) + + # Check first output + decoded_output = processor.decode(output[0], skip_special_tokens=True) + expected_output = "Write a haiku for this imageSure, here is a haiku inspired by the image:\n\nCalm lake's wooden path\nSilent forest stands guard\n" + self.assertEqual( + decoded_output, + expected_output, + f"Decoded output: {decoded_output}\nExpected output: {expected_output}", + ) + + # Check second output + decoded_output = processor.decode(output[1], skip_special_tokens=True) + expected_output = "These images depict two different landmarks. Can you identify them?Certainly! The images depict two iconic landmarks:\n\n1. The first image shows the Statue of Liberty in New York City." + self.assertEqual( + decoded_output, + expected_output, + f"Decoded output: {decoded_output}\nExpected output: {expected_output}", + ) diff --git a/tests/models/mistral3/test_processor_mistral3.py b/tests/models/mistral3/test_processor_mistral3.py new file mode 100644 index 0000000000..da9c9a759a --- /dev/null +++ b/tests/models/mistral3/test_processor_mistral3.py @@ -0,0 +1,293 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import shutil +import tempfile +import unittest + +import requests + +from transformers import PixtralProcessor +from transformers.testing_utils import require_vision +from transformers.utils import is_torch_available, is_vision_available + +from ...test_processing_common import ProcessorTesterMixin + + +if is_torch_available(): + import torch + + +if is_vision_available(): + from PIL import Image + + +@require_vision +class Mistral3ProcessorTest(ProcessorTesterMixin, unittest.TestCase): + """This tests Pixtral processor with the new `spatial_merge_size` argument in Mistral3.""" + + processor_class = PixtralProcessor + + @classmethod + def setUpClass(cls): + cls.url_0 = "https://www.ilankelman.org/stopsigns/australia.jpg" + cls.image_0 = Image.open(requests.get(cls.url_0, stream=True).raw) + cls.url_1 = "http://images.cocodataset.org/val2017/000000039769.jpg" + cls.image_1 = Image.open(requests.get(cls.url_1, stream=True).raw) + cls.url_2 = "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg" + cls.image_2 = Image.open(requests.get(cls.url_2, stream=True).raw) + + def setUp(self): + self.tmpdirname = tempfile.mkdtemp() + processor = PixtralProcessor.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503") + processor.save_pretrained(self.tmpdirname) + + def tearDown(self): + shutil.rmtree(self.tmpdirname) + + def test_chat_template(self): + processor = self.processor_class.from_pretrained(self.tmpdirname) + expected_prompt = "[INST][IMG]What is shown in this image?[/INST]" + + messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What is shown in this image?"}, + ], + }, + ] + formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True) + self.assertEqual(expected_prompt, formatted_prompt) + + def test_image_token_filling(self): + processor = self.processor_class.from_pretrained(self.tmpdirname) + # Important to check with non square image + image = torch.randint(0, 2, (3, 500, 316)) + expected_image_tokens = 198 + image_token_index = 10 + + messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What is shown in this image?"}, + ], + }, + ] + inputs = processor( + text=[processor.apply_chat_template(messages)], + images=[image], + return_tensors="pt", + ) + image_tokens = (inputs["input_ids"] == image_token_index).sum().item() + self.assertEqual(expected_image_tokens, image_tokens) + + def test_processor_with_single_image(self): + processor = self.processor_class.from_pretrained(self.tmpdirname) + prompt_string = "USER: [IMG]\nWhat's the content of the image? ASSISTANT:" + + # Make small for checking image token expansion + processor.image_processor.size = {"longest_edge": 30} + processor.patch_size = 6 + + # Test passing in an image + inputs_image = processor(text=prompt_string, images=self.image_0, return_tensors="pt") + self.assertIn("input_ids", inputs_image) + self.assertTrue(len(inputs_image["input_ids"]) == 1) + self.assertIsInstance(inputs_image["input_ids"], torch.Tensor) + self.assertIsInstance(inputs_image["pixel_values"], torch.Tensor) + self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([1, 3, 24, 30])) + + # fmt: off + input_ids = inputs_image["input_ids"] + self.assertEqual( + input_ids[0].tolist(), + # Equivalent to "USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the content of the image? ASSISTANT:" + [1, 21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 4701, 1307, 1278, 3937, 1063, 1349, 4290, 16002, 41150, 1058] + ) + # fmt: on + + # Test passing in a url + inputs_url = processor(text=prompt_string, images=self.url_0, return_tensors="pt") + self.assertIn("input_ids", inputs_url) + self.assertTrue(len(inputs_url["input_ids"]) == 1) + self.assertIsInstance(inputs_url["input_ids"], torch.Tensor) + self.assertIsInstance(inputs_image["pixel_values"], torch.Tensor) + self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([1, 3, 24, 30])) + + # fmt: off + input_ids = inputs_url["input_ids"] + self.assertEqual( + input_ids[0].tolist(), + # Equivalent to "USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the content of the image? ASSISTANT:" + [1, 21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 4701, 1307, 1278, 3937, 1063, 1349, 4290, 16002, 41150, 1058] + ) + # fmt: on + + # Test passing inputs as a single list + inputs_image = processor(text=prompt_string, images=[self.image_0], return_tensors="pt") + self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([1, 3, 24, 30])) + + # fmt: off + self.assertEqual( + inputs_image["input_ids"][0].tolist(), + [1, 21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 4701, 1307, 1278, 3937, 1063, 1349, 4290, 16002, 41150, 1058] + ) + # fmt: on + + # Test as nested single list + inputs_image = processor(text=prompt_string, images=[[self.image_0]], return_tensors="pt") + self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([1, 3, 24, 30])) + + # fmt: off + self.assertEqual( + inputs_image["input_ids"][0].tolist(), + [1, 21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 4701, 1307, 1278, 3937, 1063, 1349, 4290, 16002, 41150, 1058] + ) + # fmt: on + + def test_processor_with_multiple_images_single_list(self): + processor = self.processor_class.from_pretrained(self.tmpdirname) + prompt_string = "USER: [IMG][IMG]\nWhat's the difference between these two images? ASSISTANT:" + + # Make small for checking image token expansion + processor.image_processor.size = {"longest_edge": 30} + processor.patch_size = 6 + + # Test passing in an image + inputs_image = processor(text=prompt_string, images=[self.image_0, self.image_1], return_tensors="pt") + self.assertIn("input_ids", inputs_image) + self.assertTrue(len(inputs_image["input_ids"]) == 1) + self.assertIsInstance(inputs_image["input_ids"], torch.Tensor) + self.assertIsInstance(inputs_image["pixel_values"], torch.Tensor) + self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([2, 3, 24, 30])) + + # fmt: off + input_ids = inputs_image["input_ids"] + self.assertEqual( + input_ids[0].tolist(), + # Equivalent to ["USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END][IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the difference between these two images? ASSISTANT:"] + [1, 21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058] + ) + # fmt: on + + # Test passing in a url + inputs_url = processor(text=prompt_string, images=[self.url_0, self.url_1], return_tensors="pt") + self.assertIn("input_ids", inputs_url) + self.assertTrue(len(inputs_url["input_ids"]) == 1) + self.assertIsInstance(inputs_url["input_ids"], torch.Tensor) + self.assertIsInstance(inputs_image["pixel_values"], torch.Tensor) + self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([2, 3, 24, 30])) + + # fmt: off + input_ids = inputs_url["input_ids"] + self.assertEqual( + input_ids[0].tolist(), + # Equivalent to ["USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END][IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the difference between these two images? ASSISTANT:"] + [1, 21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058] + ) + # fmt: on + + # Test passing in as a nested list + inputs_url = processor(text=prompt_string, images=[[self.image_0, self.image_1]], return_tensors="pt") + self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([2, 3, 24, 30])) + + # fmt: off + self.assertEqual( + inputs_url["input_ids"][0].tolist(), + [1, 21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058] + ) + # fmt: on + + def test_processor_with_multiple_images_multiple_lists(self): + processor = self.processor_class.from_pretrained(self.tmpdirname) + prompt_string = [ + "USER: [IMG][IMG]\nWhat's the difference between these two images? ASSISTANT:", + "USER: [IMG]\nWhat's the content of the image? ASSISTANT:", + ] + processor.tokenizer.pad_token = "" + image_inputs = [[self.image_0, self.image_1], [self.image_2]] + + # Make small for checking image token expansion + processor.image_processor.size = {"longest_edge": 30} + processor.patch_size = 6 + + # Test passing in an image + inputs_image = processor(text=prompt_string, images=image_inputs, return_tensors="pt", padding=True) + self.assertIn("input_ids", inputs_image) + self.assertTrue(len(inputs_image["input_ids"]) == 2) + self.assertIsInstance(inputs_image["input_ids"], torch.Tensor) + self.assertIsInstance(inputs_image["pixel_values"], torch.Tensor) + self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([3, 3, 30, 30])) + + # fmt: off + input_ids = inputs_image["input_ids"] + self.assertEqual( + input_ids[0].tolist(), + # Equivalent to ["USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END][IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the difference between these two images? ASSISTANT:"] + [1, 21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058] + ) + # fmt: on + + # Test passing in a url + inputs_url = processor(text=prompt_string, images=image_inputs, return_tensors="pt", padding=True) + self.assertIn("input_ids", inputs_url) + self.assertTrue(len(inputs_url["input_ids"]) == 2) + self.assertIsInstance(inputs_url["input_ids"], torch.Tensor) + self.assertIsInstance(inputs_image["pixel_values"], torch.Tensor) + self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([3, 3, 30, 30])) + + # fmt: off + input_ids = inputs_url["input_ids"] + self.assertEqual( + input_ids[0].tolist(), + # Equivalent to ["USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END][IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the difference between these two images? ASSISTANT:"] + [1, 21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058] + ) + # fmt: on + + # Test passing as a single flat list + inputs_image = processor( + text=prompt_string, images=[self.image_0, self.image_1, self.image_2], return_tensors="pt", padding=True + ) + self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([3, 3, 30, 30])) + + # fmt: off + self.assertEqual( + inputs_image["input_ids"][0].tolist(), + [1, 21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058] + ) + # fmt: on + + def test_processor_returns_full_length_batches(self): + # to avoid https://github.com/huggingface/transformers/issues/34204 + processor = self.processor_class.from_pretrained(self.tmpdirname) + prompt_string = [ + "USER: [IMG]\nWhat's the content of the image? ASSISTANT:", + ] * 5 + processor.tokenizer.pad_token = "" + image_inputs = [[self.image_0]] * 5 + + # Make small for checking image token expansion + processor.image_processor.size = {"longest_edge": 30} + processor.patch_size = 6 + + # Test passing in an image + inputs_image = processor(text=prompt_string, images=image_inputs, return_tensors="pt", padding=True) + self.assertIn("input_ids", inputs_image) + self.assertTrue(len(inputs_image["input_ids"]) == 5) diff --git a/tests/models/pixtral/test_image_processing_pixtral.py b/tests/models/pixtral/test_image_processing_pixtral.py index a2a0243724..3f9deded6d 100644 --- a/tests/models/pixtral/test_image_processing_pixtral.py +++ b/tests/models/pixtral/test_image_processing_pixtral.py @@ -109,8 +109,8 @@ class PixtralImageProcessingTester: ratio = max(height / max_height, width / max_width) if ratio > 1: - height = int(np.ceil(height / ratio)) - width = int(np.ceil(width / ratio)) + height = int(np.floor(height / ratio)) + width = int(np.floor(width / ratio)) patch_height, patch_width = self.patch_size["height"], self.patch_size["width"] num_height_tokens = (height - 1) // patch_height + 1