Add Mistral3 (#36790)
* initial start * style and dummies * Create convert_mistral3_weights_to_hf.py * update * typo * typo * Update convert_mistral3_weights_to_hf.py * Update convert_mistral3_weights_to_hf.py * Update convert_mistral3_weights_to_hf.py * Update convert_mistral3_weights_to_hf.py * up * Update convert_mistral3_weights_to_hf.py * Update convert_mistral3_weights_to_hf.py * update * update * Update image_processing_mistral3.py * Update convert_mistral3_weights_to_hf.py * fix patch merger * Update convert_mistral3_weights_to_hf.py * Update convert_mistral3_weights_to_hf.py * up * update modular to fit * style * Update convert_mistral3_weights_to_hf.py * typo * Update modular_mistral3.py * simplify a lot all shape shenanigans * simplify * add working test processor * Add partially working common modeling tests * All tests working and remove mistral3 image processors * add docs and fixup * fix inference with image size >1540 * 🚨fix test image proc pixtral * Remove vision_feature_select_strategy * Update convert_mistral3_weights_to_hf.py * Update convert_mistral3_weights_to_hf.py * Update convert_mistral3_weights_to_hf.py * Update convert_mistral3_weights_to_hf.py * clean * fix test checkpoints * Update test_modeling_mistral3.py * Update test_modeling_mistral3.py * style * Use Pixtral processor * up * finish cleaning processor to use pixtral directly * Update __init__.py * Update processing_pixtral.py * doc * Update __init__.py * Update mistral3.md * Update _toctree.yml --------- Co-authored-by: yonigozlan <yoni.gozlan@huggingface.co> Co-authored-by: yonigozlan <yoni.gozlan10@gmail.com>
This commit is contained in:
@@ -529,6 +529,8 @@
|
||||
title: MegatronGPT2
|
||||
- local: model_doc/mistral
|
||||
title: Mistral
|
||||
- local: model_doc/mistral3
|
||||
title: Mistral3
|
||||
- local: model_doc/mixtral
|
||||
title: Mixtral
|
||||
- local: model_doc/mluke
|
||||
|
||||
234
docs/source/en/model_doc/mistral3.md
Normal file
234
docs/source/en/model_doc/mistral3.md
Normal file
@@ -0,0 +1,234 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Mistral3
|
||||
|
||||
## Overview
|
||||
|
||||
Building upon Mistral Small 3 (2501), Mistral Small 3.1 (2503) adds state-of-the-art vision understanding and enhances long context capabilities up to 128k tokens without compromising text performance. With 24 billion parameters, this model achieves top-tier capabilities in both text and vision tasks.
|
||||
|
||||
It is ideal for:
|
||||
- Fast-response conversational agents.
|
||||
- Low-latency function calling.
|
||||
- Subject matter experts via fine-tuning.
|
||||
- Local inference for hobbyists and organizations handling sensitive data.
|
||||
- Programming and math reasoning.
|
||||
- Long document understanding.
|
||||
- Visual understanding.
|
||||
|
||||
This model was contributed by [cyrilvallez](https://huggingface.co/cyrilvallez) and [yonigozlan](https://huggingface.co/yonigozlan).
|
||||
|
||||
The original code can be found [here](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/pixtral.py) and [here](https://github.com/mistralai/mistral-common).
|
||||
|
||||
## Usage example
|
||||
|
||||
### Inference with Pipeline
|
||||
|
||||
Here is how you can use the `image-text-to-text` pipeline to perform inference with the `Mistral3` models in just a few lines of code:
|
||||
```python
|
||||
>>> from transformers import pipeline
|
||||
|
||||
>>> messages = [
|
||||
... {
|
||||
... "role": "user",
|
||||
... "content": [
|
||||
... {
|
||||
... "type": "image",
|
||||
... "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg",
|
||||
... },
|
||||
... {"type": "text", "text": "Describe this image."},
|
||||
... ],
|
||||
... },
|
||||
... ]
|
||||
|
||||
>>> pipe = pipeline("image-text-to-text", model="../mistral3_weights", torch_dtype=torch.bfloat16)
|
||||
>>> outputs = pipe(text=messages, max_new_tokens=50, return_full_text=False)
|
||||
>>> outputs[0]["generated_text"]
|
||||
'The image depicts a vibrant and lush garden scene featuring a variety of wildflowers and plants. The central focus is on a large, pinkish-purple flower, likely a Greater Celandine (Chelidonium majus), with a'
|
||||
```
|
||||
### Inference on a single image
|
||||
|
||||
This example demonstrates how to perform inference on a single image with the Mistral3 models using chat templates.
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoProcessor, AutoModelForImageTextToText
|
||||
>>> import torch
|
||||
|
||||
>>> torch_device = "cuda"
|
||||
>>> model_checkpoint = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
|
||||
>>> processor = AutoProcessor.from_pretrained(model_checkpoint)
|
||||
>>> model = AutoModelForImageTextToText.from_pretrained(model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16)
|
||||
|
||||
>>> messages = [
|
||||
... {
|
||||
... "role": "user",
|
||||
... "content": [
|
||||
... {"type": "image", "url": "http://images.cocodataset.org/val2017/000000039769.jpg"},
|
||||
... {"type": "text", "text": "Describe this image"},
|
||||
... ],
|
||||
... }
|
||||
... ]
|
||||
|
||||
>>> inputs = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt").to(model.device, dtype=torch.bfloat16)
|
||||
|
||||
>>> generate_ids = model.generate(**inputs, max_new_tokens=20)
|
||||
>>> decoded_output = processor.decode(generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True)
|
||||
|
||||
>>> decoded_output
|
||||
"The image depicts two cats lying on a pink blanket. The larger cat, which appears to be an"...
|
||||
```
|
||||
|
||||
### Text-only generation
|
||||
This example shows how to generate text using the Mistral3 model without providing any image input.
|
||||
|
||||
|
||||
````python
|
||||
>>> from transformers import AutoProcessor, AutoModelForImageTextToText
|
||||
>>> import torch
|
||||
|
||||
>>> torch_device = "cuda"
|
||||
>>> model_checkpoint = ".mistralai/Mistral-Small-3.1-24B-Instruct-2503"
|
||||
>>> processor = AutoProcessor.from_pretrained(model_checkpoint)
|
||||
>>> model = AutoModelForImageTextToText.from_pretrained(model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16)
|
||||
|
||||
>>> SYSTEM_PROMPT = "You are a conversational agent that always answers straight to the point, always end your accurate response with an ASCII drawing of a cat."
|
||||
>>> user_prompt = "Give me 5 non-formal ways to say 'See you later' in French."
|
||||
|
||||
>>> messages = [
|
||||
... {"role": "system", "content": SYSTEM_PROMPT},
|
||||
... {"role": "user", "content": user_prompt},
|
||||
... ]
|
||||
|
||||
>>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
>>> inputs = processor(text=text, return_tensors="pt").to(0, dtype=torch.float16)
|
||||
>>> generate_ids = model.generate(**inputs, max_new_tokens=50, do_sample=False)
|
||||
>>> decoded_output = processor.batch_decode(generate_ids[:, inputs["input_ids"].shape[1] :], skip_special_tokens=True)[0]
|
||||
|
||||
>>> print(decoded_output)
|
||||
"1. À plus tard!
|
||||
2. Salut, à plus!
|
||||
3. À toute!
|
||||
4. À la prochaine!
|
||||
5. Je me casse, à plus!
|
||||
|
||||
```
|
||||
/\_/\
|
||||
( o.o )
|
||||
> ^ <
|
||||
```"
|
||||
````
|
||||
|
||||
### Batched image and text inputs
|
||||
Mistral3 models also support batched image and text inputs.
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoProcessor, AutoModelForImageTextToText
|
||||
>>> import torch
|
||||
|
||||
>>> torch_device = "cuda"
|
||||
>>> model_checkpoint = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
|
||||
>>> processor = AutoProcessor.from_pretrained(model_checkpoint)
|
||||
>>> model = AutoModelForImageTextToText.from_pretrained(model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16)
|
||||
|
||||
>>> messages = [
|
||||
... [
|
||||
... {
|
||||
... "role": "user",
|
||||
... "content": [
|
||||
... {"type": "image", "url": "https://llava-vl.github.io/static/images/view.jpg"},
|
||||
... {"type": "text", "text": "Write a haiku for this image"},
|
||||
... ],
|
||||
... },
|
||||
... ],
|
||||
... [
|
||||
... {
|
||||
... "role": "user",
|
||||
... "content": [
|
||||
... {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
|
||||
... {"type": "text", "text": "Describe this image"},
|
||||
... ],
|
||||
... },
|
||||
... ],
|
||||
... ]
|
||||
|
||||
|
||||
>>> inputs = processor.apply_chat_template(messages, padding=True, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt").to(model.device, dtype=torch.bfloat16)
|
||||
|
||||
>>> output = model.generate(**inputs, max_new_tokens=25)
|
||||
|
||||
>>> decoded_outputs = processor.batch_decode(output, skip_special_tokens=True)
|
||||
>>> decoded_outputs
|
||||
["Write a haiku for this imageCalm waters reflect\nWhispers of the forest's breath\nPeace on wooden path"
|
||||
, "Describe this imageThe image depicts a vibrant street scene in what appears to be a Chinatown district. The focal point is a traditional Chinese"]
|
||||
```
|
||||
|
||||
### Batched multi-image input and quantization with BitsAndBytes
|
||||
This implementation of the Mistral3 models supports batched text-images inputs with different number of images for each text.
|
||||
This example also how to use `BitsAndBytes` to load the model in 4bit quantization.
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig
|
||||
>>> import torch
|
||||
|
||||
>>> torch_device = "cuda"
|
||||
>>> model_checkpoint = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
|
||||
>>> processor = AutoProcessor.from_pretrained(model_checkpoint)
|
||||
>>> quantization_config = BitsAndBytesConfig(load_in_4bit=True)
|
||||
>>> model = AutoModelForImageTextToText.from_pretrained(
|
||||
... model_checkpoint, quantization_config=quantization_config
|
||||
... )
|
||||
|
||||
>>> messages = [
|
||||
... [
|
||||
... {
|
||||
... "role": "user",
|
||||
... "content": [
|
||||
... {"type": "image", "url": "https://llava-vl.github.io/static/images/view.jpg"},
|
||||
... {"type": "text", "text": "Write a haiku for this image"},
|
||||
... ],
|
||||
... },
|
||||
... ],
|
||||
... [
|
||||
... {
|
||||
... "role": "user",
|
||||
... "content": [
|
||||
... {"type": "image", "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"},
|
||||
... {"type": "image", "url": "https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg"},
|
||||
... {"type": "text", "text": "These images depict two different landmarks. Can you identify them?"},
|
||||
... ],
|
||||
... },
|
||||
... ],
|
||||
>>> ]
|
||||
|
||||
>>> inputs = processor.apply_chat_template(messages, padding=True, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt").to(model.device, dtype=torch.bfloat16)
|
||||
|
||||
>>> output = model.generate(**inputs, max_new_tokens=25)
|
||||
|
||||
>>> decoded_outputs = processor.batch_decode(output, skip_special_tokens=True)
|
||||
>>> decoded_outputs
|
||||
["Write a haiku for this imageSure, here is a haiku inspired by the image:\n\nCalm lake's wooden path\nSilent forest stands guard\n", "These images depict two different landmarks. Can you identify them? Certainly! The images depict two iconic landmarks:\n\n1. The first image shows the Statue of Liberty in New York City."]
|
||||
```
|
||||
|
||||
|
||||
## Mistral3Config
|
||||
|
||||
[[autodoc]] Mistral3Config
|
||||
|
||||
|
||||
## Mistral3ForConditionalGeneration
|
||||
|
||||
[[autodoc]] Mistral3ForConditionalGeneration
|
||||
- forward
|
||||
@@ -613,6 +613,7 @@ _import_structure = {
|
||||
],
|
||||
"models.mimi": ["MimiConfig"],
|
||||
"models.mistral": ["MistralConfig"],
|
||||
"models.mistral3": ["Mistral3Config"],
|
||||
"models.mixtral": ["MixtralConfig"],
|
||||
"models.mllama": [
|
||||
"MllamaConfig",
|
||||
@@ -2940,6 +2941,12 @@ else:
|
||||
"MistralPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.mistral3"].extend(
|
||||
[
|
||||
"Mistral3ForConditionalGeneration",
|
||||
"Mistral3PreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.mixtral"].extend(
|
||||
[
|
||||
"MixtralForCausalLM",
|
||||
@@ -5788,6 +5795,7 @@ if TYPE_CHECKING:
|
||||
MimiConfig,
|
||||
)
|
||||
from .models.mistral import MistralConfig
|
||||
from .models.mistral3 import Mistral3Config
|
||||
from .models.mixtral import MixtralConfig
|
||||
from .models.mllama import (
|
||||
MllamaConfig,
|
||||
@@ -7844,6 +7852,10 @@ if TYPE_CHECKING:
|
||||
MistralModel,
|
||||
MistralPreTrainedModel,
|
||||
)
|
||||
from .models.mistral3 import (
|
||||
Mistral3ForConditionalGeneration,
|
||||
Mistral3PreTrainedModel,
|
||||
)
|
||||
from .models.mixtral import (
|
||||
MixtralForCausalLM,
|
||||
MixtralForQuestionAnswering,
|
||||
|
||||
@@ -169,6 +169,7 @@ from . import (
|
||||
mgp_str,
|
||||
mimi,
|
||||
mistral,
|
||||
mistral3,
|
||||
mixtral,
|
||||
mllama,
|
||||
mluke,
|
||||
|
||||
@@ -192,6 +192,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
|
||||
("mgp-str", "MgpstrConfig"),
|
||||
("mimi", "MimiConfig"),
|
||||
("mistral", "MistralConfig"),
|
||||
("mistral3", "Mistral3Config"),
|
||||
("mixtral", "MixtralConfig"),
|
||||
("mllama", "MllamaConfig"),
|
||||
("mobilebert", "MobileBertConfig"),
|
||||
@@ -537,6 +538,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
||||
("mgp-str", "MGP-STR"),
|
||||
("mimi", "Mimi"),
|
||||
("mistral", "Mistral"),
|
||||
("mistral3", "Mistral3"),
|
||||
("mixtral", "Mixtral"),
|
||||
("mllama", "Mllama"),
|
||||
("mluke", "mLUKE"),
|
||||
|
||||
@@ -111,6 +111,7 @@ else:
|
||||
("mask2former", ("Mask2FormerImageProcessor",)),
|
||||
("maskformer", ("MaskFormerImageProcessor",)),
|
||||
("mgp-str", ("ViTImageProcessor", "ViTImageProcessorFast")),
|
||||
("mistral3", ("PixtralImageProcessor", "PixtralImageProcessorFast")),
|
||||
("mllama", ("MllamaImageProcessor",)),
|
||||
("mobilenet_v1", ("MobileNetV1ImageProcessor",)),
|
||||
("mobilenet_v2", ("MobileNetV2ImageProcessor",)),
|
||||
|
||||
@@ -361,6 +361,7 @@ MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
|
||||
("mamba2", "Mamba2ForCausalLM"),
|
||||
("mega", "MegaForMaskedLM"),
|
||||
("megatron-bert", "MegatronBertForPreTraining"),
|
||||
("mistral3", "Mistral3ForConditionalGeneration"),
|
||||
("mllama", "MllamaForConditionalGeneration"),
|
||||
("mobilebert", "MobileBertForPreTraining"),
|
||||
("mpnet", "MPNetForMaskedLM"),
|
||||
@@ -802,6 +803,7 @@ MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
|
||||
("llava_next", "LlavaNextForConditionalGeneration"),
|
||||
("llava_next_video", "LlavaNextVideoForConditionalGeneration"),
|
||||
("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
|
||||
("mistral3", "Mistral3ForConditionalGeneration"),
|
||||
("mllama", "MllamaForConditionalGeneration"),
|
||||
("paligemma", "PaliGemmaForConditionalGeneration"),
|
||||
("pix2struct", "Pix2StructForConditionalGeneration"),
|
||||
@@ -839,6 +841,7 @@ MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict(
|
||||
("llava", "LlavaForConditionalGeneration"),
|
||||
("llava_next", "LlavaNextForConditionalGeneration"),
|
||||
("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
|
||||
("mistral3", "Mistral3ForConditionalGeneration"),
|
||||
("mllama", "MllamaForConditionalGeneration"),
|
||||
("paligemma", "PaliGemmaForConditionalGeneration"),
|
||||
("pix2struct", "Pix2StructForConditionalGeneration"),
|
||||
|
||||
@@ -84,6 +84,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
|
||||
("markuplm", "MarkupLMProcessor"),
|
||||
("mctct", "MCTCTProcessor"),
|
||||
("mgp-str", "MgpstrProcessor"),
|
||||
("mistral3", "PixtralProcessor"),
|
||||
("mllama", "MllamaProcessor"),
|
||||
("moonshine", "Wav2Vec2Processor"),
|
||||
("oneformer", "OneFormerProcessor"),
|
||||
|
||||
28
src/transformers/models/mistral3/__init__.py
Normal file
28
src/transformers/models/mistral3/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import _LazyModule
|
||||
from ...utils.import_utils import define_import_structure
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_mistral3 import *
|
||||
from .modeling_mistral3 import *
|
||||
from .processing_mistral3 import *
|
||||
else:
|
||||
import sys
|
||||
|
||||
_file = globals()["__file__"]
|
||||
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
||||
137
src/transformers/models/mistral3/configuration_mistral3.py
Normal file
137
src/transformers/models/mistral3/configuration_mistral3.py
Normal file
@@ -0,0 +1,137 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ..auto import CONFIG_MAPPING, AutoConfig
|
||||
|
||||
|
||||
class Mistral3Config(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Mistral3ForConditionalGeneration`]. It is used to instantiate an
|
||||
Mistral3 model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||
with the defaults will yield a similar configuration to that of
|
||||
[mistralai/Mistral-Small-3.1-24B-Instruct-2503](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503)
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `PixtralVisionConfig`):
|
||||
The config object or dictionary of the vision backbone.
|
||||
text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `MistralConfig`):
|
||||
The config object or dictionary of the text backbone.
|
||||
image_token_index (`int`, *optional*, defaults to 10):
|
||||
The image token index to encode the image prompt.
|
||||
projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
|
||||
The activation function used by the multimodal projector.
|
||||
vision_feature_layer (`Union[int, List[int]]`, *optional*, defaults to -1):
|
||||
The index of the layer to select the vision feature. If multiple indices are provided,
|
||||
the vision feature of the corresponding indices will be concatenated to form the
|
||||
vision features.
|
||||
multimodal_projector_bias (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use bias in the multimodal projector.
|
||||
spatial_merge_size (`int`, *optional*, defaults to 2):
|
||||
The downsampling factor for the spatial merge operation.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import Mistral3ForConditionalGeneration, Mistral3Config, PixtralVisionConfig, MistralConfig
|
||||
|
||||
>>> # Initializing a Pixtral-vision config
|
||||
>>> vision_config = PixtralVisionConfig()
|
||||
|
||||
>>> # Initializing a Mistral config
|
||||
>>> text_config = MistralConfig()
|
||||
|
||||
>>> # Initializing a Mistral3 configuration
|
||||
>>> configuration = Mistral3Config(vision_config, text_config)
|
||||
|
||||
>>> # Initializing a model from the mistral3.1 configuration
|
||||
>>> model = Mistral3ForConditionalGeneration(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "mistral3"
|
||||
sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig}
|
||||
is_composition = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vision_config=None,
|
||||
text_config=None,
|
||||
image_token_index=10,
|
||||
projector_hidden_act="gelu",
|
||||
vision_feature_layer=-1,
|
||||
multimodal_projector_bias=False,
|
||||
spatial_merge_size=2,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.image_token_index = image_token_index
|
||||
self.projector_hidden_act = projector_hidden_act
|
||||
|
||||
self.vision_feature_layer = vision_feature_layer
|
||||
|
||||
if isinstance(vision_config, dict):
|
||||
vision_config["model_type"] = vision_config["model_type"] if "model_type" in vision_config else "pixtral"
|
||||
vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
|
||||
elif vision_config is None:
|
||||
vision_config = CONFIG_MAPPING["pixtral"](
|
||||
intermediate_size=4096,
|
||||
hidden_size=1024,
|
||||
patch_size=14,
|
||||
image_size=1540,
|
||||
num_hidden_layers=24,
|
||||
num_attention_heads=16,
|
||||
vocab_size=32000,
|
||||
head_dim=64,
|
||||
hidden_act="gelu",
|
||||
)
|
||||
|
||||
self.vision_config = vision_config
|
||||
|
||||
if isinstance(text_config, dict):
|
||||
text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "mistral"
|
||||
text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
|
||||
elif text_config is None:
|
||||
text_config = CONFIG_MAPPING["mistral"](
|
||||
attention_dropout=0.0,
|
||||
head_dim=128,
|
||||
hidden_act="silu",
|
||||
hidden_size=5120,
|
||||
initializer_range=0.02,
|
||||
intermediate_size=32768,
|
||||
max_position_embeddings=131072,
|
||||
model_type="mistral",
|
||||
num_attention_heads=32,
|
||||
num_hidden_layers=40,
|
||||
num_key_value_heads=8,
|
||||
rms_norm_eps=1e-05,
|
||||
rope_theta=1000000000.0,
|
||||
sliding_window=None,
|
||||
use_cache=True,
|
||||
vocab_size=131072,
|
||||
)
|
||||
|
||||
self.text_config = text_config
|
||||
self.multimodal_projector_bias = multimodal_projector_bias
|
||||
self.spatial_merge_size = spatial_merge_size
|
||||
|
||||
|
||||
__all__ = ["Mistral3Config"]
|
||||
@@ -0,0 +1,241 @@
|
||||
# Copyright 2023 Mistral AI and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from transformers import (
|
||||
Mistral3Config,
|
||||
Mistral3ForConditionalGeneration,
|
||||
MistralConfig,
|
||||
PixtralImageProcessorFast,
|
||||
PixtralProcessor,
|
||||
PixtralVisionConfig,
|
||||
)
|
||||
from transformers.integrations.mistral import convert_tekken_tokenizer
|
||||
|
||||
|
||||
# fmt: off
|
||||
STATE_DICT_MAPPING = {
|
||||
# Text model keys
|
||||
r"^output.weight": r"language_model.lm_head.weight",
|
||||
r"^norm.weight": r"language_model.model.norm.weight",
|
||||
r"^tok_embeddings.weight": r"language_model.model.embed_tokens.weight",
|
||||
r"^layers.(\d+).attention_norm.weight": r"language_model.model.layers.\1.input_layernorm.weight",
|
||||
r"^layers.(\d+).ffn_norm.weight": r"language_model.model.layers.\1.post_attention_layernorm.weight",
|
||||
r"^layers.(\d+).attention.w(q|k|v|o).weight": r"language_model.model.layers.\1.self_attn.\2_proj.weight",
|
||||
r"^layers.(\d+).feed_forward.w1.weight": r"language_model.model.layers.\1.mlp.gate_proj.weight",
|
||||
r"^layers.(\d+).feed_forward.w2.weight": r"language_model.model.layers.\1.mlp.down_proj.weight",
|
||||
r"^layers.(\d+).feed_forward.w3.weight": r"language_model.model.layers.\1.mlp.up_proj.weight",
|
||||
|
||||
# Vision model keys
|
||||
r"vision_encoder.transformer.layers.(\d+).attention_norm.weight": r"vision_tower.transformer.layers.\1.attention_norm.weight",
|
||||
r"^vision_encoder.transformer.layers.(\d+).ffn_norm.weight": r"vision_tower.transformer.layers.\1.ffn_norm.weight",
|
||||
r"^vision_encoder.transformer.layers.(\d+).attention.w(q|k|v|o).weight": r"vision_tower.transformer.layers.\1.attention.\2_proj.weight",
|
||||
r"^vision_encoder.transformer.layers.(\d+).feed_forward.w1.weight": r"vision_tower.transformer.layers.\1.feed_forward.gate_proj.weight",
|
||||
r"^vision_encoder.transformer.layers.(\d+).feed_forward.w2.weight": r"vision_tower.transformer.layers.\1.feed_forward.down_proj.weight",
|
||||
r"^vision_encoder.transformer.layers.(\d+).feed_forward.w3.weight": r"vision_tower.transformer.layers.\1.feed_forward.up_proj.weight",
|
||||
r"^vision_language_adapter.w_in": r"multi_modal_projector.linear_1",
|
||||
r"^vision_language_adapter.w_out": r"multi_modal_projector.linear_2",
|
||||
r"^vision_encoder.ln_pre.weight": r"vision_tower.ln_pre.weight",
|
||||
r"^vision_encoder.patch_conv.weight": r"vision_tower.patch_conv.weight",
|
||||
r"^patch_merger.merging_layer.weight": r"multi_modal_projector.patch_merger.merging_layer.weight",
|
||||
r"^pre_mm_projector_norm.weight": r"multi_modal_projector.norm.weight",
|
||||
}
|
||||
# fmt: on
|
||||
|
||||
|
||||
def map_old_key_to_new(old_key):
|
||||
"""Map of a key of the original state dict to the equivalent key in HF format"""
|
||||
for pattern, replacement in STATE_DICT_MAPPING.items():
|
||||
new_key, n_replace = re.subn(pattern, replacement, old_key)
|
||||
# Early exit of the loop
|
||||
if n_replace > 0:
|
||||
return new_key
|
||||
|
||||
raise ValueError(f"Key: {old_key} could not be mapped (check the mapping).")
|
||||
|
||||
|
||||
def read_json(path):
|
||||
with open(path, "r") as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def permute_for_rope(tensor, n_heads, dim1, dim2):
|
||||
"""Permute the weights for the ROPE formulation."""
|
||||
tensor = tensor.view(n_heads, dim1 // n_heads // 2, 2, dim2)
|
||||
tensor = tensor.transpose(1, 2)
|
||||
tensor = tensor.reshape(dim1, dim2)
|
||||
return tensor
|
||||
|
||||
|
||||
def convert_state_dict(original_state_dict: dict, config: MistralConfig):
|
||||
"""Convert a state dict file, when a single `nn.Module` is never sharded in different files (usual case)."""
|
||||
new_dict = {}
|
||||
|
||||
for old_key, tensor in original_state_dict.items():
|
||||
new_key = map_old_key_to_new(old_key)
|
||||
|
||||
if "vision" in old_key:
|
||||
num_attention_heads = config.vision_config.num_attention_heads
|
||||
num_key_value_heads = num_attention_heads
|
||||
hidden_size = config.vision_config.hidden_size
|
||||
head_dim = config.vision_config.head_dim
|
||||
key_value_dim = head_dim * num_attention_heads
|
||||
query_dim = head_dim * num_attention_heads
|
||||
else:
|
||||
num_attention_heads = config.text_config.num_attention_heads
|
||||
hidden_size = config.text_config.hidden_size
|
||||
head_dim = config.text_config.head_dim
|
||||
num_key_value_heads = config.text_config.num_key_value_heads
|
||||
key_value_dim = head_dim * num_key_value_heads
|
||||
query_dim = head_dim * num_attention_heads
|
||||
|
||||
if "q_proj" in new_key:
|
||||
tensor = permute_for_rope(tensor, num_attention_heads, query_dim, hidden_size)
|
||||
elif "k_proj" in new_key:
|
||||
tensor = permute_for_rope(tensor, num_key_value_heads, key_value_dim, hidden_size)
|
||||
|
||||
new_dict[new_key] = tensor
|
||||
return new_dict
|
||||
|
||||
|
||||
def convert_config(original_config: dict, max_position_embeddings: int = 131072):
|
||||
original_vision_config = original_config.pop("vision_encoder")
|
||||
original_text_config = original_config
|
||||
|
||||
# Text config
|
||||
text_key_mapping = {
|
||||
"hidden_size": "dim",
|
||||
"num_hidden_layers": "n_layers",
|
||||
"intermediate_size": "hidden_dim",
|
||||
"num_attention_heads": "n_heads",
|
||||
"num_key_value_heads": "n_kv_heads",
|
||||
"rms_norm_eps": "norm_eps",
|
||||
}
|
||||
similar_text_keys_to_keep = [
|
||||
"head_dim",
|
||||
"vocab_size",
|
||||
"rope_theta",
|
||||
]
|
||||
new_text_config_kwargs = {k: original_text_config[v] for k, v in text_key_mapping.items()}
|
||||
new_text_config_kwargs.update({k: v for k, v in original_text_config.items() if k in similar_text_keys_to_keep})
|
||||
# These are not always defined depending on `params.json`
|
||||
new_text_config_kwargs["sliding_window"] = original_text_config.get("sliding_window", None)
|
||||
new_text_config_kwargs["max_position_embeddings"] = original_text_config.get(
|
||||
"max_seq_len", max_position_embeddings
|
||||
)
|
||||
# This may sometimes be a string in `params.json`
|
||||
if new_text_config_kwargs["sliding_window"] is not None:
|
||||
new_text_config_kwargs["sliding_window"] = int(new_text_config_kwargs["sliding_window"])
|
||||
new_text_config = MistralConfig(**new_text_config_kwargs)
|
||||
|
||||
# Vision config
|
||||
new_vision_config = original_vision_config
|
||||
adapter_bias = new_vision_config.pop("adapter_bias", False)
|
||||
_ = new_vision_config.pop("mm_projector_id", None)
|
||||
_ = new_vision_config.pop("add_pre_mm_projector_layer_norm", None)
|
||||
spatial_merge_size = new_vision_config.pop("spatial_merge_size")
|
||||
image_token_id = new_vision_config.pop("image_token_id", 10)
|
||||
_ = new_vision_config.pop("image_break_token_id", 12)
|
||||
_ = new_vision_config.pop("image_end_token_id", 13)
|
||||
_ = new_vision_config.pop("max_image_size")
|
||||
new_vision_config = PixtralVisionConfig(**new_vision_config)
|
||||
|
||||
new_config = Mistral3Config(
|
||||
vision_config=new_vision_config,
|
||||
text_config=new_text_config,
|
||||
multimodal_projector_bias=adapter_bias,
|
||||
image_token_index=image_token_id,
|
||||
spatial_merge_size=spatial_merge_size,
|
||||
vision_feature_layer=-1,
|
||||
)
|
||||
return new_config
|
||||
|
||||
|
||||
def convert_and_write_model(input_dir: str, output_dir: str, max_position_embeddings: int):
|
||||
"""Convert the model and save it (this implicitly save the config as well)."""
|
||||
params = read_json(os.path.join(input_dir, "params.json"))
|
||||
config = convert_config(params, max_position_embeddings)
|
||||
|
||||
full_state_dict = {}
|
||||
# The model may be split between different files, but a single nn.Module is always fully present in a single file
|
||||
shards = [file for file in os.listdir(input_dir) if file.endswith(".safetensors")]
|
||||
for shard_file in shards:
|
||||
original_state_dict = load_file(os.path.join(input_dir, shard_file))
|
||||
new_dict = convert_state_dict(original_state_dict, config)
|
||||
full_state_dict.update(new_dict)
|
||||
|
||||
# Load weights into model and resave them
|
||||
with torch.device("meta"):
|
||||
model = Mistral3ForConditionalGeneration(config)
|
||||
model.load_state_dict(full_state_dict, strict=True, assign=True)
|
||||
model.save_pretrained(output_dir)
|
||||
|
||||
|
||||
def convert_and_write_processor(input_dir: str, output_dir: str):
|
||||
"""Convert the tokenizer and save it."""
|
||||
tokenizer_file = os.path.join(input_dir, "tekken.json")
|
||||
tokenizer = convert_tekken_tokenizer(tokenizer_file)
|
||||
tokenizer.add_special_tokens({"pad_token": "<pad>"})
|
||||
chat_template = '{%- if messages[0]["role"] == "system" %}{%- set system_message = messages[0]["content"] %}{%- set loop_messages = messages[1:] %}\n{%- else %}{%- set loop_messages = messages %}{%- endif %}{{- bos_token }}{%- for message in loop_messages %}{%- if (message[\'role\'] == \'user\') != (loop.index0 % 2 == 0) %}{{- raise_exception(\'After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\') }}{%- endif %}{%- if message["role"] == "user" %}{%- if loop.last and system_message is defined %}{{- "[INST]" + system_message + "\n\n" }}{%- else %}{{ "[INST]" }}{%- endif %}{%- endif %}{%- if message["content"] is not string %}{%- for chunk in message["content"] %}{%- if chunk["type"] == "text" %}{%- if "content" in chunk %}{{- chunk["content"] }}{%- elif "text" in chunk %}{{- chunk["text"] }}{%- endif %}{%- elif chunk["type"] == "image" %}{{- "[IMG]" }}{%- else %}{{- raise_exception("Unrecognized content type!") }}{%- endif %}{%- endfor %}{%- else %}{{- message["content"] }}{%- endif %}{%- if message["role"] == "user" %}{{- "[/INST]" }}{%- elif message["role"] == "assistant" %}{{- eos_token}}{%- else %}{{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}{%- endif %}{%- endfor %}'
|
||||
|
||||
config = read_json(os.path.join(input_dir, "params.json"))
|
||||
patch_size = config["vision_encoder"]["patch_size"]
|
||||
spatial_merge_size = config["vision_encoder"]["spatial_merge_size"]
|
||||
max_image_size = config["vision_encoder"]["max_image_size"]
|
||||
image_processor = PixtralImageProcessorFast(patch_size=patch_size, size={"longest_edge": max_image_size})
|
||||
|
||||
processor = PixtralProcessor(
|
||||
tokenizer=tokenizer,
|
||||
image_processor=image_processor,
|
||||
image_token="[IMG]",
|
||||
patch_size=patch_size,
|
||||
chat_template=chat_template,
|
||||
spatial_merge_size=spatial_merge_size,
|
||||
)
|
||||
|
||||
# Finally save it
|
||||
processor.save_pretrained(output_dir)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"input_dir",
|
||||
help="Location of Mistral weights, which contains tokenizer.model and model folders",
|
||||
)
|
||||
parser.add_argument(
|
||||
"output_dir",
|
||||
help="Location to write HF model and tokenizer",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_position_embeddings",
|
||||
type=int,
|
||||
default=131072,
|
||||
help="`max_position_embeddings` field in the config. This needs to be manually passed (not present anywhere otherwise).",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_and_write_model(args.input_dir, args.output_dir, args.max_position_embeddings)
|
||||
convert_and_write_processor(args.input_dir, args.output_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
553
src/transformers/models/mistral3/modeling_mistral3.py
Normal file
553
src/transformers/models/mistral3/modeling_mistral3.py
Normal file
@@ -0,0 +1,553 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from src/transformers/models/mistral3/modular_mistral3.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_mistral3.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_outputs import ModelOutput
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import (
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_torchdynamo_compiling,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
from ..auto import AutoModel, AutoModelForCausalLM
|
||||
from .configuration_mistral3 import Mistral3Config
|
||||
|
||||
|
||||
_CONFIG_FOR_DOC = "Mistral3Config"
|
||||
|
||||
|
||||
class Mistral3RMSNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
"""
|
||||
Mistral3RMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
def extra_repr(self):
|
||||
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
||||
|
||||
|
||||
class Mistral3PatchMerger(nn.Module):
|
||||
"""
|
||||
Learned merging of spatial_merge_size ** 2 patches
|
||||
"""
|
||||
|
||||
def __init__(self, config: Mistral3Config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
hidden_size = config.vision_config.hidden_size
|
||||
self.spatial_merge_size = config.spatial_merge_size
|
||||
self.patch_size = self.config.vision_config.patch_size
|
||||
self.merging_layer = nn.Linear(hidden_size * self.spatial_merge_size**2, hidden_size, bias=False)
|
||||
|
||||
def forward(self, image_features: torch.Tensor, image_sizes: torch.Tensor) -> torch.Tensor:
|
||||
image_sizes = [
|
||||
(image_size[0] // self.patch_size, image_size[1] // self.patch_size) for image_size in image_sizes
|
||||
]
|
||||
|
||||
tokens_per_image = [h * w for h, w in image_sizes]
|
||||
d = image_features.shape[-1]
|
||||
|
||||
permuted_tensor = []
|
||||
for image_index, image_tokens in enumerate(image_features.split(tokens_per_image)):
|
||||
# Reshape image_tokens into a 2D grid
|
||||
h, w = image_sizes[image_index]
|
||||
image_grid = image_tokens.view(h, w, d).permute(2, 0, 1).unsqueeze(0)
|
||||
grid = torch.nn.functional.unfold(
|
||||
image_grid, kernel_size=self.spatial_merge_size, stride=self.spatial_merge_size
|
||||
)
|
||||
grid = grid.view(d * self.spatial_merge_size**2, -1).t()
|
||||
permuted_tensor.append(grid)
|
||||
|
||||
image_features = torch.cat(permuted_tensor, dim=0)
|
||||
image_features = self.merging_layer(image_features)
|
||||
return image_features
|
||||
|
||||
|
||||
class Mistral3MultiModalProjector(nn.Module):
|
||||
def __init__(self, config: Mistral3Config):
|
||||
super().__init__()
|
||||
self.norm = Mistral3RMSNorm(config.vision_config.hidden_size)
|
||||
self.patch_merger = Mistral3PatchMerger(config)
|
||||
# We have hidden_size * the number of vision feature layers
|
||||
num_feature_layers = 1 if isinstance(config.vision_feature_layer, int) else len(config.vision_feature_layer)
|
||||
self.linear_1 = nn.Linear(
|
||||
config.vision_config.hidden_size * num_feature_layers,
|
||||
config.text_config.hidden_size,
|
||||
bias=config.multimodal_projector_bias,
|
||||
)
|
||||
self.act = ACT2FN[config.projector_hidden_act]
|
||||
self.linear_2 = nn.Linear(
|
||||
config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
|
||||
)
|
||||
|
||||
def forward(self, image_features: torch.Tensor, image_sizes: torch.Tensor):
|
||||
image_features = self.norm(image_features)
|
||||
image_features = self.patch_merger(image_features, image_sizes)
|
||||
hidden_states = self.linear_1(image_features)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states = self.linear_2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
@dataclass
|
||||
class Mistral3CausalLMOutputWithPast(ModelOutput):
|
||||
"""
|
||||
Base class for Mistral3 causal language model (or autoregressive) outputs.
|
||||
|
||||
Args:
|
||||
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
||||
Language modeling loss (for next-token prediction).
|
||||
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
|
||||
`past_key_values` input) to speed up sequential decoding.
|
||||
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
||||
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
||||
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
||||
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
||||
sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
image_hidden_states (`torch.FloatTensor`, *optional*):
|
||||
A `torch.FloatTensor` of size (batch_size, num_images, sequence_length, hidden_size)`.
|
||||
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
|
||||
"""
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
logits: torch.FloatTensor = None
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
image_hidden_states: Optional[torch.FloatTensor] = None
|
||||
|
||||
|
||||
MISTRAL3_START_DOCSTRING = r"""
|
||||
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
||||
etc.)
|
||||
|
||||
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
||||
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
||||
and behavior.
|
||||
|
||||
Parameters:
|
||||
config ([`Mistral3Config`] or [`Mistral3VisionConfig`]):
|
||||
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
||||
load the weights associated with the model, only the configuration. Check out the
|
||||
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
|
||||
MISTRAL3_START_DOCSTRING,
|
||||
)
|
||||
class Mistral3PreTrainedModel(PreTrainedModel):
|
||||
config_class = Mistral3Config
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["Mistral3VisionAttention"]
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
_supports_cache_class = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
# important: this ported version of Mistral3 isn't meant for training from scratch - only
|
||||
# inference and fine-tuning - so the proper init weights code has been removed - the original codebase
|
||||
# https://github.com/haotian-liu/Mistral3/tree/main/mistral3 should serve for that purpose
|
||||
std = (
|
||||
self.config.initializer_range
|
||||
if hasattr(self.config, "initializer_range")
|
||||
else self.config.text_config.initializer_range
|
||||
)
|
||||
|
||||
if hasattr(module, "class_embedding"):
|
||||
module.class_embedding.data.normal_(mean=0.0, std=std)
|
||||
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
|
||||
MISTRAL3_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
||||
it.
|
||||
|
||||
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||
[`PreTrainedTokenizer.__call__`] for details.
|
||||
|
||||
[What are input IDs?](../glossary#input-ids)
|
||||
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)):
|
||||
The tensors corresponding to the input images. Pixel values can be obtained using
|
||||
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details ([]`Mistral3Processor`] uses
|
||||
[`CLIPImageProcessor`] for processing images).
|
||||
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
|
||||
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||
[`PreTrainedTokenizer.__call__`] for details.
|
||||
|
||||
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
|
||||
`past_key_values`).
|
||||
|
||||
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
||||
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
||||
information on the default strategy.
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
||||
config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
||||
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||||
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
||||
|
||||
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
||||
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
||||
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
||||
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
||||
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
||||
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
||||
model's internal embedding lookup matrix.
|
||||
vision_feature_layer (`Union[int, List[int]], *optional*, defaults to -2`):
|
||||
The index of the layer to select the vision feature. If multiple indices are provided,
|
||||
the vision feature of the corresponding indices will be concatenated to form the
|
||||
vision features.
|
||||
vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
|
||||
The feature selection strategy used to select the vision feature from the vision backbone.
|
||||
Can be one of `"default"` or `"full"`.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
||||
`past_key_values`).
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||
tensors for more detail.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
|
||||
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
|
||||
the complete sequence length.
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""The MISTRAL3 model which consists of a vision backbone and a language model.""",
|
||||
MISTRAL3_START_DOCSTRING,
|
||||
)
|
||||
class Mistral3ForConditionalGeneration(Mistral3PreTrainedModel, GenerationMixin):
|
||||
def __init__(self, config: Mistral3Config):
|
||||
super().__init__(config)
|
||||
self.vision_tower = AutoModel.from_config(config.vision_config)
|
||||
|
||||
self.multi_modal_projector = Mistral3MultiModalProjector(config)
|
||||
self.vocab_size = config.text_config.vocab_size
|
||||
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
|
||||
|
||||
if self.language_model._tied_weights_keys is not None:
|
||||
self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
|
||||
|
||||
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
|
||||
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.language_model.get_input_embeddings()
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.language_model.set_input_embeddings(value)
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.language_model.get_output_embeddings()
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.language_model.set_output_embeddings(new_embeddings)
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.language_model.set_decoder(decoder)
|
||||
|
||||
def get_decoder(self):
|
||||
return self.language_model.get_decoder()
|
||||
|
||||
def get_image_features(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
vision_feature_layer: Union[int, List[int]],
|
||||
image_sizes: torch.Tensor,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Obtains image last hidden states from the vision tower and apply multimodal projection.
|
||||
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`):
|
||||
The tensors corresponding to the input images.
|
||||
vision_feature_layer (`Union[int, List[int]]`):
|
||||
The index of the layer to select the vision feature. If multiple indices are provided,
|
||||
the vision feature of the corresponding indices will be concatenated to form the
|
||||
vision features.
|
||||
image_sizes (`torch.Tensor`):
|
||||
Tensor containing the image sizes as returned by the processor.
|
||||
Returns:
|
||||
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
|
||||
"""
|
||||
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
# this is not memory efficient at all (output_hidden_states=True) will save all the hidden states.
|
||||
image_outputs = self.vision_tower(pixel_values, image_sizes=image_sizes, output_hidden_states=True, **kwargs)
|
||||
# If we have one vision feature layer, return the corresponding hidden states,
|
||||
# otherwise, select the hidden states of each feature layer and concatenate them
|
||||
if isinstance(vision_feature_layer, int):
|
||||
selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
|
||||
else:
|
||||
hs_pool = [image_outputs.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
|
||||
selected_image_feature = torch.cat(hs_pool, dim=-1)
|
||||
|
||||
image_features = self.multi_modal_projector(selected_image_feature.squeeze(0), image_sizes)
|
||||
return image_features
|
||||
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
@add_start_docstrings_to_model_forward(MISTRAL3_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=Mistral3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
pixel_values: torch.FloatTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
vision_feature_layer: Optional[Union[int, List[int]]] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
image_sizes: torch.Tensor = None,
|
||||
**lm_kwargs,
|
||||
) -> Union[Tuple, Mistral3CausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from transformers import AutoProcessor, Mistral3ForConditionalGeneration
|
||||
|
||||
>>> model = Mistral3ForConditionalGeneration.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
|
||||
>>> processor = AutoProcessor.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
|
||||
|
||||
>>> prompt = "<s>[INST][IMG]What is the image?[/INST]"
|
||||
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(**inputs, max_new_tokens=15)
|
||||
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"What is the image?The image depicts two cats lying on a pink blanket."
|
||||
```"""
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
vision_feature_layer = (
|
||||
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
|
||||
)
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if pixel_values is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
if pixel_values is not None:
|
||||
image_features = self.get_image_features(
|
||||
pixel_values=pixel_values,
|
||||
vision_feature_layer=vision_feature_layer,
|
||||
image_sizes=image_sizes,
|
||||
)
|
||||
|
||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
n_image_tokens = (input_ids == self.config.image_token_index).sum()
|
||||
n_image_features = image_features.shape[0] * image_features.shape[1]
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
)
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
||||
|
||||
outputs = self.language_model(
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
logits_to_keep=logits_to_keep,
|
||||
**lm_kwargs,
|
||||
)
|
||||
|
||||
logits = outputs[0]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
if attention_mask is not None:
|
||||
# we use the input attention mask to shift the logits and labels, because it is 2D.
|
||||
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
|
||||
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
|
||||
shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
|
||||
shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
|
||||
else:
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = nn.CrossEntropyLoss()
|
||||
loss = loss_fct(
|
||||
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return Mistral3CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
image_hidden_states=image_features if pixel_values is not None else None,
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
input_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
pixel_values=None,
|
||||
attention_mask=None,
|
||||
cache_position=None,
|
||||
logits_to_keep=None,
|
||||
**kwargs,
|
||||
):
|
||||
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
|
||||
|
||||
model_inputs = self.language_model.prepare_inputs_for_generation(
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
logits_to_keep=logits_to_keep,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if cache_position[0] == 0:
|
||||
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
|
||||
# Otherwise we need pixel values to be passed to model
|
||||
model_inputs["pixel_values"] = pixel_values
|
||||
|
||||
return model_inputs
|
||||
|
||||
|
||||
__all__ = ["Mistral3PreTrainedModel", "Mistral3ForConditionalGeneration"]
|
||||
286
src/transformers/models/mistral3/modular_mistral3.py
Normal file
286
src/transformers/models/mistral3/modular_mistral3.py
Normal file
@@ -0,0 +1,286 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...utils import is_torchdynamo_compiling, logging
|
||||
from ..llava.modeling_llava import LlavaCausalLMOutputWithPast, LlavaForConditionalGeneration
|
||||
from ..mistral.modeling_mistral import MistralRMSNorm
|
||||
from .configuration_mistral3 import Mistral3Config
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Mistral3RMSNorm(MistralRMSNorm):
|
||||
pass
|
||||
|
||||
|
||||
class Mistral3PatchMerger(nn.Module):
|
||||
"""
|
||||
Learned merging of spatial_merge_size ** 2 patches
|
||||
"""
|
||||
|
||||
def __init__(self, config: Mistral3Config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
hidden_size = config.vision_config.hidden_size
|
||||
self.spatial_merge_size = config.spatial_merge_size
|
||||
self.patch_size = self.config.vision_config.patch_size
|
||||
self.merging_layer = nn.Linear(hidden_size * self.spatial_merge_size**2, hidden_size, bias=False)
|
||||
|
||||
def forward(self, image_features: torch.Tensor, image_sizes: torch.Tensor) -> torch.Tensor:
|
||||
image_sizes = [
|
||||
(image_size[0] // self.patch_size, image_size[1] // self.patch_size) for image_size in image_sizes
|
||||
]
|
||||
|
||||
tokens_per_image = [h * w for h, w in image_sizes]
|
||||
d = image_features.shape[-1]
|
||||
|
||||
permuted_tensor = []
|
||||
for image_index, image_tokens in enumerate(image_features.split(tokens_per_image)):
|
||||
# Reshape image_tokens into a 2D grid
|
||||
h, w = image_sizes[image_index]
|
||||
image_grid = image_tokens.view(h, w, d).permute(2, 0, 1).unsqueeze(0)
|
||||
grid = torch.nn.functional.unfold(
|
||||
image_grid, kernel_size=self.spatial_merge_size, stride=self.spatial_merge_size
|
||||
)
|
||||
grid = grid.view(d * self.spatial_merge_size**2, -1).t()
|
||||
permuted_tensor.append(grid)
|
||||
|
||||
image_features = torch.cat(permuted_tensor, dim=0)
|
||||
image_features = self.merging_layer(image_features)
|
||||
return image_features
|
||||
|
||||
|
||||
class Mistral3MultiModalProjector(nn.Module):
|
||||
def __init__(self, config: Mistral3Config):
|
||||
super().__init__()
|
||||
self.norm = Mistral3RMSNorm(config.vision_config.hidden_size)
|
||||
self.patch_merger = Mistral3PatchMerger(config)
|
||||
# We have hidden_size * the number of vision feature layers
|
||||
num_feature_layers = 1 if isinstance(config.vision_feature_layer, int) else len(config.vision_feature_layer)
|
||||
self.linear_1 = nn.Linear(
|
||||
config.vision_config.hidden_size * num_feature_layers,
|
||||
config.text_config.hidden_size,
|
||||
bias=config.multimodal_projector_bias,
|
||||
)
|
||||
self.act = ACT2FN[config.projector_hidden_act]
|
||||
self.linear_2 = nn.Linear(
|
||||
config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
|
||||
)
|
||||
|
||||
def forward(self, image_features: torch.Tensor, image_sizes: torch.Tensor):
|
||||
image_features = self.norm(image_features)
|
||||
image_features = self.patch_merger(image_features, image_sizes)
|
||||
hidden_states = self.linear_1(image_features)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states = self.linear_2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Mistral3CausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
|
||||
pass
|
||||
|
||||
|
||||
class Mistral3ForConditionalGeneration(LlavaForConditionalGeneration):
|
||||
def get_image_features(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
vision_feature_layer: Union[int, List[int]],
|
||||
image_sizes: torch.Tensor,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Obtains image last hidden states from the vision tower and apply multimodal projection.
|
||||
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`):
|
||||
The tensors corresponding to the input images.
|
||||
vision_feature_layer (`Union[int, List[int]]`):
|
||||
The index of the layer to select the vision feature. If multiple indices are provided,
|
||||
the vision feature of the corresponding indices will be concatenated to form the
|
||||
vision features.
|
||||
image_sizes (`torch.Tensor`):
|
||||
Tensor containing the image sizes as returned by the processor.
|
||||
Returns:
|
||||
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
|
||||
"""
|
||||
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
# this is not memory efficient at all (output_hidden_states=True) will save all the hidden states.
|
||||
image_outputs = self.vision_tower(pixel_values, image_sizes=image_sizes, output_hidden_states=True, **kwargs)
|
||||
# If we have one vision feature layer, return the corresponding hidden states,
|
||||
# otherwise, select the hidden states of each feature layer and concatenate them
|
||||
if isinstance(vision_feature_layer, int):
|
||||
selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
|
||||
else:
|
||||
hs_pool = [image_outputs.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
|
||||
selected_image_feature = torch.cat(hs_pool, dim=-1)
|
||||
|
||||
image_features = self.multi_modal_projector(selected_image_feature.squeeze(0), image_sizes)
|
||||
return image_features
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
pixel_values: torch.FloatTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
vision_feature_layer: Optional[Union[int, List[int]]] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
image_sizes: torch.Tensor = None,
|
||||
**lm_kwargs,
|
||||
) -> Union[Tuple, Mistral3CausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from transformers import AutoProcessor, Mistral3ForConditionalGeneration
|
||||
|
||||
>>> model = Mistral3ForConditionalGeneration.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
|
||||
>>> processor = AutoProcessor.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
|
||||
|
||||
>>> prompt = "<s>[INST][IMG]What is the image?[/INST]"
|
||||
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(**inputs, max_new_tokens=15)
|
||||
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"What is the image?The image depicts two cats lying on a pink blanket."
|
||||
```"""
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
vision_feature_layer = (
|
||||
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
|
||||
)
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if pixel_values is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
if pixel_values is not None:
|
||||
image_features = self.get_image_features(
|
||||
pixel_values=pixel_values,
|
||||
vision_feature_layer=vision_feature_layer,
|
||||
image_sizes=image_sizes,
|
||||
)
|
||||
|
||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
n_image_tokens = (input_ids == self.config.image_token_index).sum()
|
||||
n_image_features = image_features.shape[0] * image_features.shape[1]
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
)
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
||||
|
||||
outputs = self.language_model(
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
logits_to_keep=logits_to_keep,
|
||||
**lm_kwargs,
|
||||
)
|
||||
|
||||
logits = outputs[0]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
if attention_mask is not None:
|
||||
# we use the input attention mask to shift the logits and labels, because it is 2D.
|
||||
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
|
||||
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
|
||||
shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
|
||||
shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
|
||||
else:
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = nn.CrossEntropyLoss()
|
||||
loss = loss_fct(
|
||||
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return Mistral3CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
image_hidden_states=image_features if pixel_values is not None else None,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Mistral3PreTrainedModel", # noqa
|
||||
"Mistral3ForConditionalGeneration",
|
||||
]
|
||||
@@ -128,8 +128,9 @@ def get_resize_output_image_size(
|
||||
|
||||
if ratio > 1:
|
||||
# Orgiginal implementation uses `round` which utilises bankers rounding, which can lead to surprising results
|
||||
height = int(math.ceil(height / ratio))
|
||||
width = int(math.ceil(width / ratio))
|
||||
# Here we use floor to ensure the image is always smaller than the given "longest_edge"
|
||||
height = int(math.floor(height / ratio))
|
||||
width = int(math.floor(width / ratio))
|
||||
|
||||
num_height_tokens, num_width_tokens = _num_image_tokens((height, width), (patch_height, patch_width))
|
||||
return num_height_tokens * patch_height, num_width_tokens * patch_width
|
||||
|
||||
@@ -64,6 +64,8 @@ class PixtralProcessor(ProcessorMixin):
|
||||
The tokenizer is a required input.
|
||||
patch_size (`int`, *optional*, defaults to 16):
|
||||
Patch size from the vision tower.
|
||||
spatial_merge_size (`int`, *optional*, defaults to 1):
|
||||
The downsampling factor for the spatial merge operation.
|
||||
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
|
||||
in a chat into a tokenizable string.
|
||||
image_token (`str`, *optional*, defaults to `"[IMG]"`):
|
||||
@@ -78,6 +80,7 @@ class PixtralProcessor(ProcessorMixin):
|
||||
valid_kwargs = [
|
||||
"chat_template",
|
||||
"patch_size",
|
||||
"spatial_merge_size",
|
||||
"image_token",
|
||||
"image_break_token",
|
||||
"image_end_token",
|
||||
@@ -90,6 +93,7 @@ class PixtralProcessor(ProcessorMixin):
|
||||
image_processor=None,
|
||||
tokenizer=None,
|
||||
patch_size: int = 16,
|
||||
spatial_merge_size: int = 1,
|
||||
chat_template=None,
|
||||
image_token="[IMG]", # set the default and let users change if they have peculiar special tokens in rare cases
|
||||
image_break_token="[IMG_BREAK]",
|
||||
@@ -97,6 +101,7 @@ class PixtralProcessor(ProcessorMixin):
|
||||
**kwargs,
|
||||
):
|
||||
self.patch_size = patch_size
|
||||
self.spatial_merge_size = spatial_merge_size
|
||||
self.image_token = image_token
|
||||
self.image_break_token = image_break_token
|
||||
self.image_end_token = image_end_token
|
||||
@@ -187,8 +192,8 @@ class PixtralProcessor(ProcessorMixin):
|
||||
for sample in text:
|
||||
while self.image_token in sample:
|
||||
height, width = next(image_sizes)
|
||||
num_height_tokens = height // self.patch_size
|
||||
num_width_tokens = width // self.patch_size
|
||||
num_height_tokens = height // (self.patch_size * self.spatial_merge_size)
|
||||
num_width_tokens = width // (self.patch_size * self.spatial_merge_size)
|
||||
replace_tokens = [
|
||||
[self.image_token] * num_width_tokens + [self.image_break_token]
|
||||
] * num_height_tokens
|
||||
|
||||
@@ -6392,6 +6392,20 @@ class MistralPreTrainedModel(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class Mistral3ForConditionalGeneration(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class Mistral3PreTrainedModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class MixtralForCausalLM(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -125,6 +125,7 @@ VLM_CLASS_NAMES = [
|
||||
"qwen2_5_vl",
|
||||
"ayavision",
|
||||
"gemma3",
|
||||
"mistral3",
|
||||
]
|
||||
|
||||
|
||||
|
||||
0
tests/models/mistral3/__init__.py
Normal file
0
tests/models/mistral3/__init__.py
Normal file
482
tests/models/mistral3/test_modeling_mistral3.py
Normal file
482
tests/models/mistral3/test_modeling_mistral3.py
Normal file
@@ -0,0 +1,482 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The Qwen team, Alibaba Group and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Testing suite for the PyTorch GotOcr2 model."""
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import (
|
||||
AutoProcessor,
|
||||
Mistral3Config,
|
||||
is_bitsandbytes_available,
|
||||
is_torch_available,
|
||||
)
|
||||
from transformers.testing_utils import (
|
||||
cleanup,
|
||||
require_bitsandbytes,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
|
||||
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
Mistral3ForConditionalGeneration,
|
||||
)
|
||||
|
||||
|
||||
if is_bitsandbytes_available():
|
||||
from transformers import BitsAndBytesConfig
|
||||
|
||||
|
||||
class Mistral3VisionText2TextModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=3,
|
||||
seq_length=7,
|
||||
image_seq_length=4,
|
||||
vision_feature_layer=-1,
|
||||
ignore_index=-100,
|
||||
bos_token_id=0,
|
||||
eos_token_id=0,
|
||||
pad_token_id=0,
|
||||
image_token_index=1,
|
||||
num_channels=3,
|
||||
image_size=30,
|
||||
model_type="mistral3",
|
||||
is_training=True,
|
||||
text_config={
|
||||
"model_type": "mistral",
|
||||
"vocab_size": 99,
|
||||
"attention_dropout": 0.0,
|
||||
"hidden_act": "silu",
|
||||
"hidden_size": 32,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 37,
|
||||
"max_position_embeddings": 512,
|
||||
"num_attention_heads": 4,
|
||||
"num_hidden_layers": 2,
|
||||
"num_key_value_heads": 2,
|
||||
"rms_norm_eps": 1e-05,
|
||||
"rope_theta": 1000000000.0,
|
||||
"sliding_window": None,
|
||||
"bos_token_id": 0,
|
||||
"eos_token_id": 0,
|
||||
"pad_token_id": 0,
|
||||
},
|
||||
vision_config={
|
||||
"model_type": "pixtral",
|
||||
"hidden_size": 32,
|
||||
"num_hidden_layers": 2,
|
||||
"num_attention_heads": 4,
|
||||
"intermediate_size": 37,
|
||||
"image_size": 30,
|
||||
"patch_size": 6,
|
||||
"num_channels": 3,
|
||||
"hidden_act": "gelu",
|
||||
},
|
||||
):
|
||||
self.parent = parent
|
||||
self.ignore_index = ignore_index
|
||||
self.bos_token_id = bos_token_id
|
||||
self.eos_token_id = eos_token_id
|
||||
self.pad_token_id = pad_token_id
|
||||
self.image_token_index = image_token_index
|
||||
self.model_type = model_type
|
||||
self.text_config = text_config
|
||||
self.vision_config = vision_config
|
||||
self.batch_size = batch_size
|
||||
self.vision_feature_layer = vision_feature_layer
|
||||
self.is_training = is_training
|
||||
self.image_seq_length = image_seq_length
|
||||
self.num_channels = num_channels
|
||||
self.image_size = image_size
|
||||
self.seq_length = seq_length + self.image_seq_length
|
||||
|
||||
self.num_hidden_layers = text_config["num_hidden_layers"]
|
||||
self.vocab_size = text_config["vocab_size"]
|
||||
self.hidden_size = text_config["hidden_size"]
|
||||
self.num_attention_heads = text_config["num_attention_heads"]
|
||||
|
||||
def get_config(self):
|
||||
return Mistral3Config(
|
||||
text_config=self.text_config,
|
||||
vision_config=self.vision_config,
|
||||
model_type=self.model_type,
|
||||
bos_token_id=self.bos_token_id,
|
||||
eos_token_id=self.eos_token_id,
|
||||
pad_token_id=self.pad_token_id,
|
||||
image_token_index=self.image_token_index,
|
||||
image_seq_length=self.image_seq_length,
|
||||
vision_feature_layer=self.vision_feature_layer,
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
config = self.get_config()
|
||||
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||
|
||||
return config, pixel_values
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, pixel_values = config_and_inputs
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
|
||||
image_sizes = torch.tensor(
|
||||
[[self.image_size, self.image_size]] * self.batch_size, dtype=torch.long, device=torch_device
|
||||
)
|
||||
|
||||
# input_ids[:, -1] = self.pad_token_id
|
||||
input_ids[input_ids == self.image_token_index] = self.pad_token_id
|
||||
input_ids[:, : self.image_seq_length] = self.image_token_index
|
||||
|
||||
inputs_dict = {
|
||||
"pixel_values": pixel_values,
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"image_sizes": image_sizes,
|
||||
}
|
||||
return config, inputs_dict
|
||||
|
||||
def create_and_check_model_fp16_forward(self, config, input_ids, pixel_values, attention_mask):
|
||||
model = Mistral3ForConditionalGeneration(config=config)
|
||||
model.to(torch_device)
|
||||
model.half()
|
||||
model.eval()
|
||||
logits = model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
pixel_values=pixel_values.to(torch.bfloat16),
|
||||
return_dict=True,
|
||||
)["logits"]
|
||||
self.parent.assertFalse(torch.isnan(logits).any().item())
|
||||
|
||||
def create_and_check_model_fp16_autocast_forward(self, config, input_ids, pixel_values, attention_mask):
|
||||
config.torch_dtype = torch.float16
|
||||
model = Mistral3ForConditionalGeneration(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.autocast(device_type="cuda", dtype=torch.float16):
|
||||
logits = model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
pixel_values=pixel_values.to(torch.bfloat16),
|
||||
return_dict=True,
|
||||
)["logits"]
|
||||
self.parent.assertFalse(torch.isnan(logits).any().item())
|
||||
|
||||
|
||||
@require_torch
|
||||
class Mistral3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (Mistral3ForConditionalGeneration,) if is_torch_available() else ()
|
||||
all_generative_model_classes = (Mistral3ForConditionalGeneration,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"image-text-to-text": Mistral3ForConditionalGeneration,
|
||||
}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
_is_composite = True
|
||||
test_headmasking = False
|
||||
test_pruning = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = Mistral3VisionText2TextModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=Mistral3Config, has_text_modality=False)
|
||||
|
||||
def test_config(self):
|
||||
# overwritten from `tests/test_configuration_common.py::ConfigTester` after #36077
|
||||
# TODO: avoid overwritten once there is a better fix for #36077
|
||||
def check_config_can_be_init_without_params():
|
||||
config = self.config_tester.config_class()
|
||||
self.config_tester.parent.assertIsNotNone(config)
|
||||
|
||||
self.config_tester.check_config_can_be_init_without_params = check_config_can_be_init_without_params
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_initialization(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
configs_no_init = _config_zero_init(config)
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config=configs_no_init)
|
||||
for name, param in model.named_parameters():
|
||||
if param.requires_grad:
|
||||
self.assertIn(
|
||||
((param.data.mean() * 1e9).round() / 1e9).item(),
|
||||
[0.0, 1.0],
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
)
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
def test_inputs_embeds(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
|
||||
wte = model.get_input_embeddings()
|
||||
inputs["inputs_embeds"] = wte(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
model(**inputs)
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
# while some other models require pixel_values to be present
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
|
||||
inputs_embeds = model.get_input_embeddings()(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
out_ids = model(input_ids=input_ids, **inputs)[0]
|
||||
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
|
||||
torch.testing.assert_close(out_embeds, out_ids)
|
||||
|
||||
@unittest.skip(reason="Compile not yet supported because in LLava models")
|
||||
def test_sdpa_can_compile_dynamic(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
|
||||
def test_flash_attn_2_fp32_ln(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Pixtral does not support attention interfaces.")
|
||||
def test_eager_matches_fa2_generate(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Pixtral does not support attention interfaces.")
|
||||
def test_eager_matches_sdpa_generate(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Pixtral does not support attention interfaces.")
|
||||
def test_flash_attn_2_from_config(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Pixtral does not support attention interfaces.")
|
||||
def test_flash_attn_2_inference_equivalence(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Pixtral does not support attention interfaces.")
|
||||
def test_flash_attn_2_inference_equivalence_right_padding(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Pixtral does not support attention interfaces.")
|
||||
def test_sdpa_can_dispatch_on_flash(self):
|
||||
pass
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class Mistral3IntegrationTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.model_checkpoint = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
|
||||
|
||||
def tearDown(self):
|
||||
cleanup(torch_device, gc_collect=True)
|
||||
|
||||
def test_mistral3_integration_generate_text_only(self):
|
||||
processor = AutoProcessor.from_pretrained(self.model_checkpoint)
|
||||
model = Mistral3ForConditionalGeneration.from_pretrained(
|
||||
self.model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Write a haiku"},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
|
||||
).to(torch_device, dtype=torch.bfloat16)
|
||||
|
||||
with torch.no_grad():
|
||||
generate_ids = model.generate(**inputs, max_new_tokens=200, do_sample=False)
|
||||
decoded_output = processor.decode(
|
||||
generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True
|
||||
)
|
||||
expected_output = "Sure, here's a haiku for you:\n\nWhispers of the breeze,\nCherry blossoms softly fall,\nSpring's gentle embrace."
|
||||
self.assertEqual(decoded_output, expected_output)
|
||||
|
||||
def test_mistral3_integration_generate(self):
|
||||
processor = AutoProcessor.from_pretrained(self.model_checkpoint)
|
||||
model = Mistral3ForConditionalGeneration.from_pretrained(
|
||||
self.model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16
|
||||
)
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "url": "http://images.cocodataset.org/val2017/000000039769.jpg"},
|
||||
{"type": "text", "text": "Describe this image"},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
|
||||
).to(torch_device, dtype=torch.bfloat16)
|
||||
with torch.no_grad():
|
||||
generate_ids = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||
decoded_output = processor.decode(
|
||||
generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True
|
||||
)
|
||||
expected_output = "The image depicts two cats lying on a pink blanket. The larger cat, which appears to be an"
|
||||
self.assertEqual(decoded_output, expected_output)
|
||||
|
||||
def test_mistral3_integration_batched_generate(self):
|
||||
processor = AutoProcessor.from_pretrained(self.model_checkpoint)
|
||||
model = Mistral3ForConditionalGeneration.from_pretrained(
|
||||
self.model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16
|
||||
)
|
||||
messages = [
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "url": "https://llava-vl.github.io/static/images/view.jpg"},
|
||||
{"type": "text", "text": "Write a haiku for this image"},
|
||||
],
|
||||
},
|
||||
],
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
|
||||
{"type": "text", "text": "Describe this image"},
|
||||
],
|
||||
},
|
||||
],
|
||||
]
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
messages, padding=True, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
|
||||
).to(model.device, dtype=torch.bfloat16)
|
||||
|
||||
output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
|
||||
|
||||
# Check first output
|
||||
decoded_output = processor.decode(output[0], skip_special_tokens=True)
|
||||
expected_output = "Write a haiku for this imageSure, here is a haiku inspired by the image:\n\nCalm lake's mirror gleams,\nWhispering pines"
|
||||
self.assertEqual(
|
||||
decoded_output,
|
||||
expected_output,
|
||||
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
|
||||
)
|
||||
|
||||
# Check second output
|
||||
decoded_output = processor.decode(output[1], skip_special_tokens=True)
|
||||
expected_output = "Describe this imageThe image depicts a vibrant street scene in what appears to be a Chinatown district. The focal point is a traditional Chinese"
|
||||
self.assertEqual(
|
||||
decoded_output,
|
||||
expected_output,
|
||||
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
|
||||
)
|
||||
|
||||
@require_bitsandbytes
|
||||
def test_mistral3_integration_batched_generate_multi_image(self):
|
||||
processor = AutoProcessor.from_pretrained(self.model_checkpoint)
|
||||
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
|
||||
model = Mistral3ForConditionalGeneration.from_pretrained(
|
||||
self.model_checkpoint, quantization_config=quantization_config
|
||||
)
|
||||
|
||||
# Prepare inputs
|
||||
messages = [
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "url": "https://llava-vl.github.io/static/images/view.jpg"},
|
||||
{"type": "text", "text": "Write a haiku for this image"},
|
||||
],
|
||||
},
|
||||
],
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
|
||||
},
|
||||
{
|
||||
"type": "image",
|
||||
"url": "https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg",
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "These images depict two different landmarks. Can you identify them?",
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
]
|
||||
inputs = processor.apply_chat_template(
|
||||
messages, padding=True, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
|
||||
).to(model.device, dtype=torch.float16)
|
||||
|
||||
output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
|
||||
|
||||
# Check first output
|
||||
decoded_output = processor.decode(output[0], skip_special_tokens=True)
|
||||
expected_output = "Write a haiku for this imageSure, here is a haiku inspired by the image:\n\nCalm lake's wooden path\nSilent forest stands guard\n"
|
||||
self.assertEqual(
|
||||
decoded_output,
|
||||
expected_output,
|
||||
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
|
||||
)
|
||||
|
||||
# Check second output
|
||||
decoded_output = processor.decode(output[1], skip_special_tokens=True)
|
||||
expected_output = "These images depict two different landmarks. Can you identify them?Certainly! The images depict two iconic landmarks:\n\n1. The first image shows the Statue of Liberty in New York City."
|
||||
self.assertEqual(
|
||||
decoded_output,
|
||||
expected_output,
|
||||
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
|
||||
)
|
||||
293
tests/models/mistral3/test_processor_mistral3.py
Normal file
293
tests/models/mistral3/test_processor_mistral3.py
Normal file
@@ -0,0 +1,293 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import requests
|
||||
|
||||
from transformers import PixtralProcessor
|
||||
from transformers.testing_utils import require_vision
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
from ...test_processing_common import ProcessorTesterMixin
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
@require_vision
|
||||
class Mistral3ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
"""This tests Pixtral processor with the new `spatial_merge_size` argument in Mistral3."""
|
||||
|
||||
processor_class = PixtralProcessor
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.url_0 = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
||||
cls.image_0 = Image.open(requests.get(cls.url_0, stream=True).raw)
|
||||
cls.url_1 = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
cls.image_1 = Image.open(requests.get(cls.url_1, stream=True).raw)
|
||||
cls.url_2 = "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg"
|
||||
cls.image_2 = Image.open(requests.get(cls.url_2, stream=True).raw)
|
||||
|
||||
def setUp(self):
|
||||
self.tmpdirname = tempfile.mkdtemp()
|
||||
processor = PixtralProcessor.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
|
||||
processor.save_pretrained(self.tmpdirname)
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.tmpdirname)
|
||||
|
||||
def test_chat_template(self):
|
||||
processor = self.processor_class.from_pretrained(self.tmpdirname)
|
||||
expected_prompt = "<s>[INST][IMG]What is shown in this image?[/INST]"
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": "What is shown in this image?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
|
||||
self.assertEqual(expected_prompt, formatted_prompt)
|
||||
|
||||
def test_image_token_filling(self):
|
||||
processor = self.processor_class.from_pretrained(self.tmpdirname)
|
||||
# Important to check with non square image
|
||||
image = torch.randint(0, 2, (3, 500, 316))
|
||||
expected_image_tokens = 198
|
||||
image_token_index = 10
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": "What is shown in this image?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
inputs = processor(
|
||||
text=[processor.apply_chat_template(messages)],
|
||||
images=[image],
|
||||
return_tensors="pt",
|
||||
)
|
||||
image_tokens = (inputs["input_ids"] == image_token_index).sum().item()
|
||||
self.assertEqual(expected_image_tokens, image_tokens)
|
||||
|
||||
def test_processor_with_single_image(self):
|
||||
processor = self.processor_class.from_pretrained(self.tmpdirname)
|
||||
prompt_string = "USER: [IMG]\nWhat's the content of the image? ASSISTANT:"
|
||||
|
||||
# Make small for checking image token expansion
|
||||
processor.image_processor.size = {"longest_edge": 30}
|
||||
processor.patch_size = 6
|
||||
|
||||
# Test passing in an image
|
||||
inputs_image = processor(text=prompt_string, images=self.image_0, return_tensors="pt")
|
||||
self.assertIn("input_ids", inputs_image)
|
||||
self.assertTrue(len(inputs_image["input_ids"]) == 1)
|
||||
self.assertIsInstance(inputs_image["input_ids"], torch.Tensor)
|
||||
self.assertIsInstance(inputs_image["pixel_values"], torch.Tensor)
|
||||
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([1, 3, 24, 30]))
|
||||
|
||||
# fmt: off
|
||||
input_ids = inputs_image["input_ids"]
|
||||
self.assertEqual(
|
||||
input_ids[0].tolist(),
|
||||
# Equivalent to "USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the content of the image? ASSISTANT:"
|
||||
[1, 21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 4701, 1307, 1278, 3937, 1063, 1349, 4290, 16002, 41150, 1058]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
# Test passing in a url
|
||||
inputs_url = processor(text=prompt_string, images=self.url_0, return_tensors="pt")
|
||||
self.assertIn("input_ids", inputs_url)
|
||||
self.assertTrue(len(inputs_url["input_ids"]) == 1)
|
||||
self.assertIsInstance(inputs_url["input_ids"], torch.Tensor)
|
||||
self.assertIsInstance(inputs_image["pixel_values"], torch.Tensor)
|
||||
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([1, 3, 24, 30]))
|
||||
|
||||
# fmt: off
|
||||
input_ids = inputs_url["input_ids"]
|
||||
self.assertEqual(
|
||||
input_ids[0].tolist(),
|
||||
# Equivalent to "USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the content of the image? ASSISTANT:"
|
||||
[1, 21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 4701, 1307, 1278, 3937, 1063, 1349, 4290, 16002, 41150, 1058]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
# Test passing inputs as a single list
|
||||
inputs_image = processor(text=prompt_string, images=[self.image_0], return_tensors="pt")
|
||||
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([1, 3, 24, 30]))
|
||||
|
||||
# fmt: off
|
||||
self.assertEqual(
|
||||
inputs_image["input_ids"][0].tolist(),
|
||||
[1, 21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 4701, 1307, 1278, 3937, 1063, 1349, 4290, 16002, 41150, 1058]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
# Test as nested single list
|
||||
inputs_image = processor(text=prompt_string, images=[[self.image_0]], return_tensors="pt")
|
||||
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([1, 3, 24, 30]))
|
||||
|
||||
# fmt: off
|
||||
self.assertEqual(
|
||||
inputs_image["input_ids"][0].tolist(),
|
||||
[1, 21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 4701, 1307, 1278, 3937, 1063, 1349, 4290, 16002, 41150, 1058]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
def test_processor_with_multiple_images_single_list(self):
|
||||
processor = self.processor_class.from_pretrained(self.tmpdirname)
|
||||
prompt_string = "USER: [IMG][IMG]\nWhat's the difference between these two images? ASSISTANT:"
|
||||
|
||||
# Make small for checking image token expansion
|
||||
processor.image_processor.size = {"longest_edge": 30}
|
||||
processor.patch_size = 6
|
||||
|
||||
# Test passing in an image
|
||||
inputs_image = processor(text=prompt_string, images=[self.image_0, self.image_1], return_tensors="pt")
|
||||
self.assertIn("input_ids", inputs_image)
|
||||
self.assertTrue(len(inputs_image["input_ids"]) == 1)
|
||||
self.assertIsInstance(inputs_image["input_ids"], torch.Tensor)
|
||||
self.assertIsInstance(inputs_image["pixel_values"], torch.Tensor)
|
||||
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([2, 3, 24, 30]))
|
||||
|
||||
# fmt: off
|
||||
input_ids = inputs_image["input_ids"]
|
||||
self.assertEqual(
|
||||
input_ids[0].tolist(),
|
||||
# Equivalent to ["USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END][IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the difference between these two images? ASSISTANT:"]
|
||||
[1, 21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
# Test passing in a url
|
||||
inputs_url = processor(text=prompt_string, images=[self.url_0, self.url_1], return_tensors="pt")
|
||||
self.assertIn("input_ids", inputs_url)
|
||||
self.assertTrue(len(inputs_url["input_ids"]) == 1)
|
||||
self.assertIsInstance(inputs_url["input_ids"], torch.Tensor)
|
||||
self.assertIsInstance(inputs_image["pixel_values"], torch.Tensor)
|
||||
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([2, 3, 24, 30]))
|
||||
|
||||
# fmt: off
|
||||
input_ids = inputs_url["input_ids"]
|
||||
self.assertEqual(
|
||||
input_ids[0].tolist(),
|
||||
# Equivalent to ["USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END][IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the difference between these two images? ASSISTANT:"]
|
||||
[1, 21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
# Test passing in as a nested list
|
||||
inputs_url = processor(text=prompt_string, images=[[self.image_0, self.image_1]], return_tensors="pt")
|
||||
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([2, 3, 24, 30]))
|
||||
|
||||
# fmt: off
|
||||
self.assertEqual(
|
||||
inputs_url["input_ids"][0].tolist(),
|
||||
[1, 21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
def test_processor_with_multiple_images_multiple_lists(self):
|
||||
processor = self.processor_class.from_pretrained(self.tmpdirname)
|
||||
prompt_string = [
|
||||
"USER: [IMG][IMG]\nWhat's the difference between these two images? ASSISTANT:",
|
||||
"USER: [IMG]\nWhat's the content of the image? ASSISTANT:",
|
||||
]
|
||||
processor.tokenizer.pad_token = "</s>"
|
||||
image_inputs = [[self.image_0, self.image_1], [self.image_2]]
|
||||
|
||||
# Make small for checking image token expansion
|
||||
processor.image_processor.size = {"longest_edge": 30}
|
||||
processor.patch_size = 6
|
||||
|
||||
# Test passing in an image
|
||||
inputs_image = processor(text=prompt_string, images=image_inputs, return_tensors="pt", padding=True)
|
||||
self.assertIn("input_ids", inputs_image)
|
||||
self.assertTrue(len(inputs_image["input_ids"]) == 2)
|
||||
self.assertIsInstance(inputs_image["input_ids"], torch.Tensor)
|
||||
self.assertIsInstance(inputs_image["pixel_values"], torch.Tensor)
|
||||
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([3, 3, 30, 30]))
|
||||
|
||||
# fmt: off
|
||||
input_ids = inputs_image["input_ids"]
|
||||
self.assertEqual(
|
||||
input_ids[0].tolist(),
|
||||
# Equivalent to ["USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END][IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the difference between these two images? ASSISTANT:"]
|
||||
[1, 21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
# Test passing in a url
|
||||
inputs_url = processor(text=prompt_string, images=image_inputs, return_tensors="pt", padding=True)
|
||||
self.assertIn("input_ids", inputs_url)
|
||||
self.assertTrue(len(inputs_url["input_ids"]) == 2)
|
||||
self.assertIsInstance(inputs_url["input_ids"], torch.Tensor)
|
||||
self.assertIsInstance(inputs_image["pixel_values"], torch.Tensor)
|
||||
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([3, 3, 30, 30]))
|
||||
|
||||
# fmt: off
|
||||
input_ids = inputs_url["input_ids"]
|
||||
self.assertEqual(
|
||||
input_ids[0].tolist(),
|
||||
# Equivalent to ["USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END][IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the difference between these two images? ASSISTANT:"]
|
||||
[1, 21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
# Test passing as a single flat list
|
||||
inputs_image = processor(
|
||||
text=prompt_string, images=[self.image_0, self.image_1, self.image_2], return_tensors="pt", padding=True
|
||||
)
|
||||
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([3, 3, 30, 30]))
|
||||
|
||||
# fmt: off
|
||||
self.assertEqual(
|
||||
inputs_image["input_ids"][0].tolist(),
|
||||
[1, 21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
def test_processor_returns_full_length_batches(self):
|
||||
# to avoid https://github.com/huggingface/transformers/issues/34204
|
||||
processor = self.processor_class.from_pretrained(self.tmpdirname)
|
||||
prompt_string = [
|
||||
"USER: [IMG]\nWhat's the content of the image? ASSISTANT:",
|
||||
] * 5
|
||||
processor.tokenizer.pad_token = "</s>"
|
||||
image_inputs = [[self.image_0]] * 5
|
||||
|
||||
# Make small for checking image token expansion
|
||||
processor.image_processor.size = {"longest_edge": 30}
|
||||
processor.patch_size = 6
|
||||
|
||||
# Test passing in an image
|
||||
inputs_image = processor(text=prompt_string, images=image_inputs, return_tensors="pt", padding=True)
|
||||
self.assertIn("input_ids", inputs_image)
|
||||
self.assertTrue(len(inputs_image["input_ids"]) == 5)
|
||||
@@ -109,8 +109,8 @@ class PixtralImageProcessingTester:
|
||||
|
||||
ratio = max(height / max_height, width / max_width)
|
||||
if ratio > 1:
|
||||
height = int(np.ceil(height / ratio))
|
||||
width = int(np.ceil(width / ratio))
|
||||
height = int(np.floor(height / ratio))
|
||||
width = int(np.floor(width / ratio))
|
||||
|
||||
patch_height, patch_width = self.patch_size["height"], self.patch_size["width"]
|
||||
num_height_tokens = (height - 1) // patch_height + 1
|
||||
|
||||
Reference in New Issue
Block a user