Add Mistral3 (#36790)
* initial start * style and dummies * Create convert_mistral3_weights_to_hf.py * update * typo * typo * Update convert_mistral3_weights_to_hf.py * Update convert_mistral3_weights_to_hf.py * Update convert_mistral3_weights_to_hf.py * Update convert_mistral3_weights_to_hf.py * up * Update convert_mistral3_weights_to_hf.py * Update convert_mistral3_weights_to_hf.py * update * update * Update image_processing_mistral3.py * Update convert_mistral3_weights_to_hf.py * fix patch merger * Update convert_mistral3_weights_to_hf.py * Update convert_mistral3_weights_to_hf.py * up * update modular to fit * style * Update convert_mistral3_weights_to_hf.py * typo * Update modular_mistral3.py * simplify a lot all shape shenanigans * simplify * add working test processor * Add partially working common modeling tests * All tests working and remove mistral3 image processors * add docs and fixup * fix inference with image size >1540 * 🚨fix test image proc pixtral * Remove vision_feature_select_strategy * Update convert_mistral3_weights_to_hf.py * Update convert_mistral3_weights_to_hf.py * Update convert_mistral3_weights_to_hf.py * Update convert_mistral3_weights_to_hf.py * clean * fix test checkpoints * Update test_modeling_mistral3.py * Update test_modeling_mistral3.py * style * Use Pixtral processor * up * finish cleaning processor to use pixtral directly * Update __init__.py * Update processing_pixtral.py * doc * Update __init__.py * Update mistral3.md * Update _toctree.yml --------- Co-authored-by: yonigozlan <yoni.gozlan@huggingface.co> Co-authored-by: yonigozlan <yoni.gozlan10@gmail.com>
This commit is contained in:
@@ -529,6 +529,8 @@
|
|||||||
title: MegatronGPT2
|
title: MegatronGPT2
|
||||||
- local: model_doc/mistral
|
- local: model_doc/mistral
|
||||||
title: Mistral
|
title: Mistral
|
||||||
|
- local: model_doc/mistral3
|
||||||
|
title: Mistral3
|
||||||
- local: model_doc/mixtral
|
- local: model_doc/mixtral
|
||||||
title: Mixtral
|
title: Mixtral
|
||||||
- local: model_doc/mluke
|
- local: model_doc/mluke
|
||||||
|
|||||||
234
docs/source/en/model_doc/mistral3.md
Normal file
234
docs/source/en/model_doc/mistral3.md
Normal file
@@ -0,0 +1,234 @@
|
|||||||
|
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||||
|
the License. You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||||
|
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||||
|
specific language governing permissions and limitations under the License.
|
||||||
|
|
||||||
|
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||||
|
rendered properly in your Markdown viewer.
|
||||||
|
|
||||||
|
-->
|
||||||
|
|
||||||
|
# Mistral3
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
Building upon Mistral Small 3 (2501), Mistral Small 3.1 (2503) adds state-of-the-art vision understanding and enhances long context capabilities up to 128k tokens without compromising text performance. With 24 billion parameters, this model achieves top-tier capabilities in both text and vision tasks.
|
||||||
|
|
||||||
|
It is ideal for:
|
||||||
|
- Fast-response conversational agents.
|
||||||
|
- Low-latency function calling.
|
||||||
|
- Subject matter experts via fine-tuning.
|
||||||
|
- Local inference for hobbyists and organizations handling sensitive data.
|
||||||
|
- Programming and math reasoning.
|
||||||
|
- Long document understanding.
|
||||||
|
- Visual understanding.
|
||||||
|
|
||||||
|
This model was contributed by [cyrilvallez](https://huggingface.co/cyrilvallez) and [yonigozlan](https://huggingface.co/yonigozlan).
|
||||||
|
|
||||||
|
The original code can be found [here](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/pixtral.py) and [here](https://github.com/mistralai/mistral-common).
|
||||||
|
|
||||||
|
## Usage example
|
||||||
|
|
||||||
|
### Inference with Pipeline
|
||||||
|
|
||||||
|
Here is how you can use the `image-text-to-text` pipeline to perform inference with the `Mistral3` models in just a few lines of code:
|
||||||
|
```python
|
||||||
|
>>> from transformers import pipeline
|
||||||
|
|
||||||
|
>>> messages = [
|
||||||
|
... {
|
||||||
|
... "role": "user",
|
||||||
|
... "content": [
|
||||||
|
... {
|
||||||
|
... "type": "image",
|
||||||
|
... "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg",
|
||||||
|
... },
|
||||||
|
... {"type": "text", "text": "Describe this image."},
|
||||||
|
... ],
|
||||||
|
... },
|
||||||
|
... ]
|
||||||
|
|
||||||
|
>>> pipe = pipeline("image-text-to-text", model="../mistral3_weights", torch_dtype=torch.bfloat16)
|
||||||
|
>>> outputs = pipe(text=messages, max_new_tokens=50, return_full_text=False)
|
||||||
|
>>> outputs[0]["generated_text"]
|
||||||
|
'The image depicts a vibrant and lush garden scene featuring a variety of wildflowers and plants. The central focus is on a large, pinkish-purple flower, likely a Greater Celandine (Chelidonium majus), with a'
|
||||||
|
```
|
||||||
|
### Inference on a single image
|
||||||
|
|
||||||
|
This example demonstrates how to perform inference on a single image with the Mistral3 models using chat templates.
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import AutoProcessor, AutoModelForImageTextToText
|
||||||
|
>>> import torch
|
||||||
|
|
||||||
|
>>> torch_device = "cuda"
|
||||||
|
>>> model_checkpoint = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
|
||||||
|
>>> processor = AutoProcessor.from_pretrained(model_checkpoint)
|
||||||
|
>>> model = AutoModelForImageTextToText.from_pretrained(model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
>>> messages = [
|
||||||
|
... {
|
||||||
|
... "role": "user",
|
||||||
|
... "content": [
|
||||||
|
... {"type": "image", "url": "http://images.cocodataset.org/val2017/000000039769.jpg"},
|
||||||
|
... {"type": "text", "text": "Describe this image"},
|
||||||
|
... ],
|
||||||
|
... }
|
||||||
|
... ]
|
||||||
|
|
||||||
|
>>> inputs = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt").to(model.device, dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
>>> generate_ids = model.generate(**inputs, max_new_tokens=20)
|
||||||
|
>>> decoded_output = processor.decode(generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True)
|
||||||
|
|
||||||
|
>>> decoded_output
|
||||||
|
"The image depicts two cats lying on a pink blanket. The larger cat, which appears to be an"...
|
||||||
|
```
|
||||||
|
|
||||||
|
### Text-only generation
|
||||||
|
This example shows how to generate text using the Mistral3 model without providing any image input.
|
||||||
|
|
||||||
|
|
||||||
|
````python
|
||||||
|
>>> from transformers import AutoProcessor, AutoModelForImageTextToText
|
||||||
|
>>> import torch
|
||||||
|
|
||||||
|
>>> torch_device = "cuda"
|
||||||
|
>>> model_checkpoint = ".mistralai/Mistral-Small-3.1-24B-Instruct-2503"
|
||||||
|
>>> processor = AutoProcessor.from_pretrained(model_checkpoint)
|
||||||
|
>>> model = AutoModelForImageTextToText.from_pretrained(model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
>>> SYSTEM_PROMPT = "You are a conversational agent that always answers straight to the point, always end your accurate response with an ASCII drawing of a cat."
|
||||||
|
>>> user_prompt = "Give me 5 non-formal ways to say 'See you later' in French."
|
||||||
|
|
||||||
|
>>> messages = [
|
||||||
|
... {"role": "system", "content": SYSTEM_PROMPT},
|
||||||
|
... {"role": "user", "content": user_prompt},
|
||||||
|
... ]
|
||||||
|
|
||||||
|
>>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||||
|
>>> inputs = processor(text=text, return_tensors="pt").to(0, dtype=torch.float16)
|
||||||
|
>>> generate_ids = model.generate(**inputs, max_new_tokens=50, do_sample=False)
|
||||||
|
>>> decoded_output = processor.batch_decode(generate_ids[:, inputs["input_ids"].shape[1] :], skip_special_tokens=True)[0]
|
||||||
|
|
||||||
|
>>> print(decoded_output)
|
||||||
|
"1. À plus tard!
|
||||||
|
2. Salut, à plus!
|
||||||
|
3. À toute!
|
||||||
|
4. À la prochaine!
|
||||||
|
5. Je me casse, à plus!
|
||||||
|
|
||||||
|
```
|
||||||
|
/\_/\
|
||||||
|
( o.o )
|
||||||
|
> ^ <
|
||||||
|
```"
|
||||||
|
````
|
||||||
|
|
||||||
|
### Batched image and text inputs
|
||||||
|
Mistral3 models also support batched image and text inputs.
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import AutoProcessor, AutoModelForImageTextToText
|
||||||
|
>>> import torch
|
||||||
|
|
||||||
|
>>> torch_device = "cuda"
|
||||||
|
>>> model_checkpoint = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
|
||||||
|
>>> processor = AutoProcessor.from_pretrained(model_checkpoint)
|
||||||
|
>>> model = AutoModelForImageTextToText.from_pretrained(model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
>>> messages = [
|
||||||
|
... [
|
||||||
|
... {
|
||||||
|
... "role": "user",
|
||||||
|
... "content": [
|
||||||
|
... {"type": "image", "url": "https://llava-vl.github.io/static/images/view.jpg"},
|
||||||
|
... {"type": "text", "text": "Write a haiku for this image"},
|
||||||
|
... ],
|
||||||
|
... },
|
||||||
|
... ],
|
||||||
|
... [
|
||||||
|
... {
|
||||||
|
... "role": "user",
|
||||||
|
... "content": [
|
||||||
|
... {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
|
||||||
|
... {"type": "text", "text": "Describe this image"},
|
||||||
|
... ],
|
||||||
|
... },
|
||||||
|
... ],
|
||||||
|
... ]
|
||||||
|
|
||||||
|
|
||||||
|
>>> inputs = processor.apply_chat_template(messages, padding=True, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt").to(model.device, dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
>>> output = model.generate(**inputs, max_new_tokens=25)
|
||||||
|
|
||||||
|
>>> decoded_outputs = processor.batch_decode(output, skip_special_tokens=True)
|
||||||
|
>>> decoded_outputs
|
||||||
|
["Write a haiku for this imageCalm waters reflect\nWhispers of the forest's breath\nPeace on wooden path"
|
||||||
|
, "Describe this imageThe image depicts a vibrant street scene in what appears to be a Chinatown district. The focal point is a traditional Chinese"]
|
||||||
|
```
|
||||||
|
|
||||||
|
### Batched multi-image input and quantization with BitsAndBytes
|
||||||
|
This implementation of the Mistral3 models supports batched text-images inputs with different number of images for each text.
|
||||||
|
This example also how to use `BitsAndBytes` to load the model in 4bit quantization.
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig
|
||||||
|
>>> import torch
|
||||||
|
|
||||||
|
>>> torch_device = "cuda"
|
||||||
|
>>> model_checkpoint = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
|
||||||
|
>>> processor = AutoProcessor.from_pretrained(model_checkpoint)
|
||||||
|
>>> quantization_config = BitsAndBytesConfig(load_in_4bit=True)
|
||||||
|
>>> model = AutoModelForImageTextToText.from_pretrained(
|
||||||
|
... model_checkpoint, quantization_config=quantization_config
|
||||||
|
... )
|
||||||
|
|
||||||
|
>>> messages = [
|
||||||
|
... [
|
||||||
|
... {
|
||||||
|
... "role": "user",
|
||||||
|
... "content": [
|
||||||
|
... {"type": "image", "url": "https://llava-vl.github.io/static/images/view.jpg"},
|
||||||
|
... {"type": "text", "text": "Write a haiku for this image"},
|
||||||
|
... ],
|
||||||
|
... },
|
||||||
|
... ],
|
||||||
|
... [
|
||||||
|
... {
|
||||||
|
... "role": "user",
|
||||||
|
... "content": [
|
||||||
|
... {"type": "image", "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"},
|
||||||
|
... {"type": "image", "url": "https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg"},
|
||||||
|
... {"type": "text", "text": "These images depict two different landmarks. Can you identify them?"},
|
||||||
|
... ],
|
||||||
|
... },
|
||||||
|
... ],
|
||||||
|
>>> ]
|
||||||
|
|
||||||
|
>>> inputs = processor.apply_chat_template(messages, padding=True, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt").to(model.device, dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
>>> output = model.generate(**inputs, max_new_tokens=25)
|
||||||
|
|
||||||
|
>>> decoded_outputs = processor.batch_decode(output, skip_special_tokens=True)
|
||||||
|
>>> decoded_outputs
|
||||||
|
["Write a haiku for this imageSure, here is a haiku inspired by the image:\n\nCalm lake's wooden path\nSilent forest stands guard\n", "These images depict two different landmarks. Can you identify them? Certainly! The images depict two iconic landmarks:\n\n1. The first image shows the Statue of Liberty in New York City."]
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## Mistral3Config
|
||||||
|
|
||||||
|
[[autodoc]] Mistral3Config
|
||||||
|
|
||||||
|
|
||||||
|
## Mistral3ForConditionalGeneration
|
||||||
|
|
||||||
|
[[autodoc]] Mistral3ForConditionalGeneration
|
||||||
|
- forward
|
||||||
@@ -613,6 +613,7 @@ _import_structure = {
|
|||||||
],
|
],
|
||||||
"models.mimi": ["MimiConfig"],
|
"models.mimi": ["MimiConfig"],
|
||||||
"models.mistral": ["MistralConfig"],
|
"models.mistral": ["MistralConfig"],
|
||||||
|
"models.mistral3": ["Mistral3Config"],
|
||||||
"models.mixtral": ["MixtralConfig"],
|
"models.mixtral": ["MixtralConfig"],
|
||||||
"models.mllama": [
|
"models.mllama": [
|
||||||
"MllamaConfig",
|
"MllamaConfig",
|
||||||
@@ -2940,6 +2941,12 @@ else:
|
|||||||
"MistralPreTrainedModel",
|
"MistralPreTrainedModel",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
_import_structure["models.mistral3"].extend(
|
||||||
|
[
|
||||||
|
"Mistral3ForConditionalGeneration",
|
||||||
|
"Mistral3PreTrainedModel",
|
||||||
|
]
|
||||||
|
)
|
||||||
_import_structure["models.mixtral"].extend(
|
_import_structure["models.mixtral"].extend(
|
||||||
[
|
[
|
||||||
"MixtralForCausalLM",
|
"MixtralForCausalLM",
|
||||||
@@ -5788,6 +5795,7 @@ if TYPE_CHECKING:
|
|||||||
MimiConfig,
|
MimiConfig,
|
||||||
)
|
)
|
||||||
from .models.mistral import MistralConfig
|
from .models.mistral import MistralConfig
|
||||||
|
from .models.mistral3 import Mistral3Config
|
||||||
from .models.mixtral import MixtralConfig
|
from .models.mixtral import MixtralConfig
|
||||||
from .models.mllama import (
|
from .models.mllama import (
|
||||||
MllamaConfig,
|
MllamaConfig,
|
||||||
@@ -7844,6 +7852,10 @@ if TYPE_CHECKING:
|
|||||||
MistralModel,
|
MistralModel,
|
||||||
MistralPreTrainedModel,
|
MistralPreTrainedModel,
|
||||||
)
|
)
|
||||||
|
from .models.mistral3 import (
|
||||||
|
Mistral3ForConditionalGeneration,
|
||||||
|
Mistral3PreTrainedModel,
|
||||||
|
)
|
||||||
from .models.mixtral import (
|
from .models.mixtral import (
|
||||||
MixtralForCausalLM,
|
MixtralForCausalLM,
|
||||||
MixtralForQuestionAnswering,
|
MixtralForQuestionAnswering,
|
||||||
|
|||||||
@@ -169,6 +169,7 @@ from . import (
|
|||||||
mgp_str,
|
mgp_str,
|
||||||
mimi,
|
mimi,
|
||||||
mistral,
|
mistral,
|
||||||
|
mistral3,
|
||||||
mixtral,
|
mixtral,
|
||||||
mllama,
|
mllama,
|
||||||
mluke,
|
mluke,
|
||||||
|
|||||||
@@ -192,6 +192,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
|
|||||||
("mgp-str", "MgpstrConfig"),
|
("mgp-str", "MgpstrConfig"),
|
||||||
("mimi", "MimiConfig"),
|
("mimi", "MimiConfig"),
|
||||||
("mistral", "MistralConfig"),
|
("mistral", "MistralConfig"),
|
||||||
|
("mistral3", "Mistral3Config"),
|
||||||
("mixtral", "MixtralConfig"),
|
("mixtral", "MixtralConfig"),
|
||||||
("mllama", "MllamaConfig"),
|
("mllama", "MllamaConfig"),
|
||||||
("mobilebert", "MobileBertConfig"),
|
("mobilebert", "MobileBertConfig"),
|
||||||
@@ -537,6 +538,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
|||||||
("mgp-str", "MGP-STR"),
|
("mgp-str", "MGP-STR"),
|
||||||
("mimi", "Mimi"),
|
("mimi", "Mimi"),
|
||||||
("mistral", "Mistral"),
|
("mistral", "Mistral"),
|
||||||
|
("mistral3", "Mistral3"),
|
||||||
("mixtral", "Mixtral"),
|
("mixtral", "Mixtral"),
|
||||||
("mllama", "Mllama"),
|
("mllama", "Mllama"),
|
||||||
("mluke", "mLUKE"),
|
("mluke", "mLUKE"),
|
||||||
|
|||||||
@@ -111,6 +111,7 @@ else:
|
|||||||
("mask2former", ("Mask2FormerImageProcessor",)),
|
("mask2former", ("Mask2FormerImageProcessor",)),
|
||||||
("maskformer", ("MaskFormerImageProcessor",)),
|
("maskformer", ("MaskFormerImageProcessor",)),
|
||||||
("mgp-str", ("ViTImageProcessor", "ViTImageProcessorFast")),
|
("mgp-str", ("ViTImageProcessor", "ViTImageProcessorFast")),
|
||||||
|
("mistral3", ("PixtralImageProcessor", "PixtralImageProcessorFast")),
|
||||||
("mllama", ("MllamaImageProcessor",)),
|
("mllama", ("MllamaImageProcessor",)),
|
||||||
("mobilenet_v1", ("MobileNetV1ImageProcessor",)),
|
("mobilenet_v1", ("MobileNetV1ImageProcessor",)),
|
||||||
("mobilenet_v2", ("MobileNetV2ImageProcessor",)),
|
("mobilenet_v2", ("MobileNetV2ImageProcessor",)),
|
||||||
|
|||||||
@@ -361,6 +361,7 @@ MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
|
|||||||
("mamba2", "Mamba2ForCausalLM"),
|
("mamba2", "Mamba2ForCausalLM"),
|
||||||
("mega", "MegaForMaskedLM"),
|
("mega", "MegaForMaskedLM"),
|
||||||
("megatron-bert", "MegatronBertForPreTraining"),
|
("megatron-bert", "MegatronBertForPreTraining"),
|
||||||
|
("mistral3", "Mistral3ForConditionalGeneration"),
|
||||||
("mllama", "MllamaForConditionalGeneration"),
|
("mllama", "MllamaForConditionalGeneration"),
|
||||||
("mobilebert", "MobileBertForPreTraining"),
|
("mobilebert", "MobileBertForPreTraining"),
|
||||||
("mpnet", "MPNetForMaskedLM"),
|
("mpnet", "MPNetForMaskedLM"),
|
||||||
@@ -802,6 +803,7 @@ MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
|
|||||||
("llava_next", "LlavaNextForConditionalGeneration"),
|
("llava_next", "LlavaNextForConditionalGeneration"),
|
||||||
("llava_next_video", "LlavaNextVideoForConditionalGeneration"),
|
("llava_next_video", "LlavaNextVideoForConditionalGeneration"),
|
||||||
("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
|
("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
|
||||||
|
("mistral3", "Mistral3ForConditionalGeneration"),
|
||||||
("mllama", "MllamaForConditionalGeneration"),
|
("mllama", "MllamaForConditionalGeneration"),
|
||||||
("paligemma", "PaliGemmaForConditionalGeneration"),
|
("paligemma", "PaliGemmaForConditionalGeneration"),
|
||||||
("pix2struct", "Pix2StructForConditionalGeneration"),
|
("pix2struct", "Pix2StructForConditionalGeneration"),
|
||||||
@@ -839,6 +841,7 @@ MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict(
|
|||||||
("llava", "LlavaForConditionalGeneration"),
|
("llava", "LlavaForConditionalGeneration"),
|
||||||
("llava_next", "LlavaNextForConditionalGeneration"),
|
("llava_next", "LlavaNextForConditionalGeneration"),
|
||||||
("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
|
("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
|
||||||
|
("mistral3", "Mistral3ForConditionalGeneration"),
|
||||||
("mllama", "MllamaForConditionalGeneration"),
|
("mllama", "MllamaForConditionalGeneration"),
|
||||||
("paligemma", "PaliGemmaForConditionalGeneration"),
|
("paligemma", "PaliGemmaForConditionalGeneration"),
|
||||||
("pix2struct", "Pix2StructForConditionalGeneration"),
|
("pix2struct", "Pix2StructForConditionalGeneration"),
|
||||||
|
|||||||
@@ -84,6 +84,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
|
|||||||
("markuplm", "MarkupLMProcessor"),
|
("markuplm", "MarkupLMProcessor"),
|
||||||
("mctct", "MCTCTProcessor"),
|
("mctct", "MCTCTProcessor"),
|
||||||
("mgp-str", "MgpstrProcessor"),
|
("mgp-str", "MgpstrProcessor"),
|
||||||
|
("mistral3", "PixtralProcessor"),
|
||||||
("mllama", "MllamaProcessor"),
|
("mllama", "MllamaProcessor"),
|
||||||
("moonshine", "Wav2Vec2Processor"),
|
("moonshine", "Wav2Vec2Processor"),
|
||||||
("oneformer", "OneFormerProcessor"),
|
("oneformer", "OneFormerProcessor"),
|
||||||
|
|||||||
28
src/transformers/models/mistral3/__init__.py
Normal file
28
src/transformers/models/mistral3/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from ...utils import _LazyModule
|
||||||
|
from ...utils.import_utils import define_import_structure
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .configuration_mistral3 import *
|
||||||
|
from .modeling_mistral3 import *
|
||||||
|
from .processing_mistral3 import *
|
||||||
|
else:
|
||||||
|
import sys
|
||||||
|
|
||||||
|
_file = globals()["__file__"]
|
||||||
|
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
||||||
137
src/transformers/models/mistral3/configuration_mistral3.py
Normal file
137
src/transformers/models/mistral3/configuration_mistral3.py
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from ...configuration_utils import PretrainedConfig
|
||||||
|
from ..auto import CONFIG_MAPPING, AutoConfig
|
||||||
|
|
||||||
|
|
||||||
|
class Mistral3Config(PretrainedConfig):
|
||||||
|
r"""
|
||||||
|
This is the configuration class to store the configuration of a [`Mistral3ForConditionalGeneration`]. It is used to instantiate an
|
||||||
|
Mistral3 model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||||
|
with the defaults will yield a similar configuration to that of
|
||||||
|
[mistralai/Mistral-Small-3.1-24B-Instruct-2503](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503)
|
||||||
|
|
||||||
|
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||||
|
documentation from [`PretrainedConfig`] for more information.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `PixtralVisionConfig`):
|
||||||
|
The config object or dictionary of the vision backbone.
|
||||||
|
text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `MistralConfig`):
|
||||||
|
The config object or dictionary of the text backbone.
|
||||||
|
image_token_index (`int`, *optional*, defaults to 10):
|
||||||
|
The image token index to encode the image prompt.
|
||||||
|
projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
|
||||||
|
The activation function used by the multimodal projector.
|
||||||
|
vision_feature_layer (`Union[int, List[int]]`, *optional*, defaults to -1):
|
||||||
|
The index of the layer to select the vision feature. If multiple indices are provided,
|
||||||
|
the vision feature of the corresponding indices will be concatenated to form the
|
||||||
|
vision features.
|
||||||
|
multimodal_projector_bias (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to use bias in the multimodal projector.
|
||||||
|
spatial_merge_size (`int`, *optional*, defaults to 2):
|
||||||
|
The downsampling factor for the spatial merge operation.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import Mistral3ForConditionalGeneration, Mistral3Config, PixtralVisionConfig, MistralConfig
|
||||||
|
|
||||||
|
>>> # Initializing a Pixtral-vision config
|
||||||
|
>>> vision_config = PixtralVisionConfig()
|
||||||
|
|
||||||
|
>>> # Initializing a Mistral config
|
||||||
|
>>> text_config = MistralConfig()
|
||||||
|
|
||||||
|
>>> # Initializing a Mistral3 configuration
|
||||||
|
>>> configuration = Mistral3Config(vision_config, text_config)
|
||||||
|
|
||||||
|
>>> # Initializing a model from the mistral3.1 configuration
|
||||||
|
>>> model = Mistral3ForConditionalGeneration(configuration)
|
||||||
|
|
||||||
|
>>> # Accessing the model configuration
|
||||||
|
>>> configuration = model.config
|
||||||
|
```"""
|
||||||
|
|
||||||
|
model_type = "mistral3"
|
||||||
|
sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig}
|
||||||
|
is_composition = True
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vision_config=None,
|
||||||
|
text_config=None,
|
||||||
|
image_token_index=10,
|
||||||
|
projector_hidden_act="gelu",
|
||||||
|
vision_feature_layer=-1,
|
||||||
|
multimodal_projector_bias=False,
|
||||||
|
spatial_merge_size=2,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.image_token_index = image_token_index
|
||||||
|
self.projector_hidden_act = projector_hidden_act
|
||||||
|
|
||||||
|
self.vision_feature_layer = vision_feature_layer
|
||||||
|
|
||||||
|
if isinstance(vision_config, dict):
|
||||||
|
vision_config["model_type"] = vision_config["model_type"] if "model_type" in vision_config else "pixtral"
|
||||||
|
vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
|
||||||
|
elif vision_config is None:
|
||||||
|
vision_config = CONFIG_MAPPING["pixtral"](
|
||||||
|
intermediate_size=4096,
|
||||||
|
hidden_size=1024,
|
||||||
|
patch_size=14,
|
||||||
|
image_size=1540,
|
||||||
|
num_hidden_layers=24,
|
||||||
|
num_attention_heads=16,
|
||||||
|
vocab_size=32000,
|
||||||
|
head_dim=64,
|
||||||
|
hidden_act="gelu",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.vision_config = vision_config
|
||||||
|
|
||||||
|
if isinstance(text_config, dict):
|
||||||
|
text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "mistral"
|
||||||
|
text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
|
||||||
|
elif text_config is None:
|
||||||
|
text_config = CONFIG_MAPPING["mistral"](
|
||||||
|
attention_dropout=0.0,
|
||||||
|
head_dim=128,
|
||||||
|
hidden_act="silu",
|
||||||
|
hidden_size=5120,
|
||||||
|
initializer_range=0.02,
|
||||||
|
intermediate_size=32768,
|
||||||
|
max_position_embeddings=131072,
|
||||||
|
model_type="mistral",
|
||||||
|
num_attention_heads=32,
|
||||||
|
num_hidden_layers=40,
|
||||||
|
num_key_value_heads=8,
|
||||||
|
rms_norm_eps=1e-05,
|
||||||
|
rope_theta=1000000000.0,
|
||||||
|
sliding_window=None,
|
||||||
|
use_cache=True,
|
||||||
|
vocab_size=131072,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.text_config = text_config
|
||||||
|
self.multimodal_projector_bias = multimodal_projector_bias
|
||||||
|
self.spatial_merge_size = spatial_merge_size
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["Mistral3Config"]
|
||||||
@@ -0,0 +1,241 @@
|
|||||||
|
# Copyright 2023 Mistral AI and The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
|
from transformers import (
|
||||||
|
Mistral3Config,
|
||||||
|
Mistral3ForConditionalGeneration,
|
||||||
|
MistralConfig,
|
||||||
|
PixtralImageProcessorFast,
|
||||||
|
PixtralProcessor,
|
||||||
|
PixtralVisionConfig,
|
||||||
|
)
|
||||||
|
from transformers.integrations.mistral import convert_tekken_tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
STATE_DICT_MAPPING = {
|
||||||
|
# Text model keys
|
||||||
|
r"^output.weight": r"language_model.lm_head.weight",
|
||||||
|
r"^norm.weight": r"language_model.model.norm.weight",
|
||||||
|
r"^tok_embeddings.weight": r"language_model.model.embed_tokens.weight",
|
||||||
|
r"^layers.(\d+).attention_norm.weight": r"language_model.model.layers.\1.input_layernorm.weight",
|
||||||
|
r"^layers.(\d+).ffn_norm.weight": r"language_model.model.layers.\1.post_attention_layernorm.weight",
|
||||||
|
r"^layers.(\d+).attention.w(q|k|v|o).weight": r"language_model.model.layers.\1.self_attn.\2_proj.weight",
|
||||||
|
r"^layers.(\d+).feed_forward.w1.weight": r"language_model.model.layers.\1.mlp.gate_proj.weight",
|
||||||
|
r"^layers.(\d+).feed_forward.w2.weight": r"language_model.model.layers.\1.mlp.down_proj.weight",
|
||||||
|
r"^layers.(\d+).feed_forward.w3.weight": r"language_model.model.layers.\1.mlp.up_proj.weight",
|
||||||
|
|
||||||
|
# Vision model keys
|
||||||
|
r"vision_encoder.transformer.layers.(\d+).attention_norm.weight": r"vision_tower.transformer.layers.\1.attention_norm.weight",
|
||||||
|
r"^vision_encoder.transformer.layers.(\d+).ffn_norm.weight": r"vision_tower.transformer.layers.\1.ffn_norm.weight",
|
||||||
|
r"^vision_encoder.transformer.layers.(\d+).attention.w(q|k|v|o).weight": r"vision_tower.transformer.layers.\1.attention.\2_proj.weight",
|
||||||
|
r"^vision_encoder.transformer.layers.(\d+).feed_forward.w1.weight": r"vision_tower.transformer.layers.\1.feed_forward.gate_proj.weight",
|
||||||
|
r"^vision_encoder.transformer.layers.(\d+).feed_forward.w2.weight": r"vision_tower.transformer.layers.\1.feed_forward.down_proj.weight",
|
||||||
|
r"^vision_encoder.transformer.layers.(\d+).feed_forward.w3.weight": r"vision_tower.transformer.layers.\1.feed_forward.up_proj.weight",
|
||||||
|
r"^vision_language_adapter.w_in": r"multi_modal_projector.linear_1",
|
||||||
|
r"^vision_language_adapter.w_out": r"multi_modal_projector.linear_2",
|
||||||
|
r"^vision_encoder.ln_pre.weight": r"vision_tower.ln_pre.weight",
|
||||||
|
r"^vision_encoder.patch_conv.weight": r"vision_tower.patch_conv.weight",
|
||||||
|
r"^patch_merger.merging_layer.weight": r"multi_modal_projector.patch_merger.merging_layer.weight",
|
||||||
|
r"^pre_mm_projector_norm.weight": r"multi_modal_projector.norm.weight",
|
||||||
|
}
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
|
def map_old_key_to_new(old_key):
|
||||||
|
"""Map of a key of the original state dict to the equivalent key in HF format"""
|
||||||
|
for pattern, replacement in STATE_DICT_MAPPING.items():
|
||||||
|
new_key, n_replace = re.subn(pattern, replacement, old_key)
|
||||||
|
# Early exit of the loop
|
||||||
|
if n_replace > 0:
|
||||||
|
return new_key
|
||||||
|
|
||||||
|
raise ValueError(f"Key: {old_key} could not be mapped (check the mapping).")
|
||||||
|
|
||||||
|
|
||||||
|
def read_json(path):
|
||||||
|
with open(path, "r") as f:
|
||||||
|
return json.load(f)
|
||||||
|
|
||||||
|
|
||||||
|
def permute_for_rope(tensor, n_heads, dim1, dim2):
|
||||||
|
"""Permute the weights for the ROPE formulation."""
|
||||||
|
tensor = tensor.view(n_heads, dim1 // n_heads // 2, 2, dim2)
|
||||||
|
tensor = tensor.transpose(1, 2)
|
||||||
|
tensor = tensor.reshape(dim1, dim2)
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
|
def convert_state_dict(original_state_dict: dict, config: MistralConfig):
|
||||||
|
"""Convert a state dict file, when a single `nn.Module` is never sharded in different files (usual case)."""
|
||||||
|
new_dict = {}
|
||||||
|
|
||||||
|
for old_key, tensor in original_state_dict.items():
|
||||||
|
new_key = map_old_key_to_new(old_key)
|
||||||
|
|
||||||
|
if "vision" in old_key:
|
||||||
|
num_attention_heads = config.vision_config.num_attention_heads
|
||||||
|
num_key_value_heads = num_attention_heads
|
||||||
|
hidden_size = config.vision_config.hidden_size
|
||||||
|
head_dim = config.vision_config.head_dim
|
||||||
|
key_value_dim = head_dim * num_attention_heads
|
||||||
|
query_dim = head_dim * num_attention_heads
|
||||||
|
else:
|
||||||
|
num_attention_heads = config.text_config.num_attention_heads
|
||||||
|
hidden_size = config.text_config.hidden_size
|
||||||
|
head_dim = config.text_config.head_dim
|
||||||
|
num_key_value_heads = config.text_config.num_key_value_heads
|
||||||
|
key_value_dim = head_dim * num_key_value_heads
|
||||||
|
query_dim = head_dim * num_attention_heads
|
||||||
|
|
||||||
|
if "q_proj" in new_key:
|
||||||
|
tensor = permute_for_rope(tensor, num_attention_heads, query_dim, hidden_size)
|
||||||
|
elif "k_proj" in new_key:
|
||||||
|
tensor = permute_for_rope(tensor, num_key_value_heads, key_value_dim, hidden_size)
|
||||||
|
|
||||||
|
new_dict[new_key] = tensor
|
||||||
|
return new_dict
|
||||||
|
|
||||||
|
|
||||||
|
def convert_config(original_config: dict, max_position_embeddings: int = 131072):
|
||||||
|
original_vision_config = original_config.pop("vision_encoder")
|
||||||
|
original_text_config = original_config
|
||||||
|
|
||||||
|
# Text config
|
||||||
|
text_key_mapping = {
|
||||||
|
"hidden_size": "dim",
|
||||||
|
"num_hidden_layers": "n_layers",
|
||||||
|
"intermediate_size": "hidden_dim",
|
||||||
|
"num_attention_heads": "n_heads",
|
||||||
|
"num_key_value_heads": "n_kv_heads",
|
||||||
|
"rms_norm_eps": "norm_eps",
|
||||||
|
}
|
||||||
|
similar_text_keys_to_keep = [
|
||||||
|
"head_dim",
|
||||||
|
"vocab_size",
|
||||||
|
"rope_theta",
|
||||||
|
]
|
||||||
|
new_text_config_kwargs = {k: original_text_config[v] for k, v in text_key_mapping.items()}
|
||||||
|
new_text_config_kwargs.update({k: v for k, v in original_text_config.items() if k in similar_text_keys_to_keep})
|
||||||
|
# These are not always defined depending on `params.json`
|
||||||
|
new_text_config_kwargs["sliding_window"] = original_text_config.get("sliding_window", None)
|
||||||
|
new_text_config_kwargs["max_position_embeddings"] = original_text_config.get(
|
||||||
|
"max_seq_len", max_position_embeddings
|
||||||
|
)
|
||||||
|
# This may sometimes be a string in `params.json`
|
||||||
|
if new_text_config_kwargs["sliding_window"] is not None:
|
||||||
|
new_text_config_kwargs["sliding_window"] = int(new_text_config_kwargs["sliding_window"])
|
||||||
|
new_text_config = MistralConfig(**new_text_config_kwargs)
|
||||||
|
|
||||||
|
# Vision config
|
||||||
|
new_vision_config = original_vision_config
|
||||||
|
adapter_bias = new_vision_config.pop("adapter_bias", False)
|
||||||
|
_ = new_vision_config.pop("mm_projector_id", None)
|
||||||
|
_ = new_vision_config.pop("add_pre_mm_projector_layer_norm", None)
|
||||||
|
spatial_merge_size = new_vision_config.pop("spatial_merge_size")
|
||||||
|
image_token_id = new_vision_config.pop("image_token_id", 10)
|
||||||
|
_ = new_vision_config.pop("image_break_token_id", 12)
|
||||||
|
_ = new_vision_config.pop("image_end_token_id", 13)
|
||||||
|
_ = new_vision_config.pop("max_image_size")
|
||||||
|
new_vision_config = PixtralVisionConfig(**new_vision_config)
|
||||||
|
|
||||||
|
new_config = Mistral3Config(
|
||||||
|
vision_config=new_vision_config,
|
||||||
|
text_config=new_text_config,
|
||||||
|
multimodal_projector_bias=adapter_bias,
|
||||||
|
image_token_index=image_token_id,
|
||||||
|
spatial_merge_size=spatial_merge_size,
|
||||||
|
vision_feature_layer=-1,
|
||||||
|
)
|
||||||
|
return new_config
|
||||||
|
|
||||||
|
|
||||||
|
def convert_and_write_model(input_dir: str, output_dir: str, max_position_embeddings: int):
|
||||||
|
"""Convert the model and save it (this implicitly save the config as well)."""
|
||||||
|
params = read_json(os.path.join(input_dir, "params.json"))
|
||||||
|
config = convert_config(params, max_position_embeddings)
|
||||||
|
|
||||||
|
full_state_dict = {}
|
||||||
|
# The model may be split between different files, but a single nn.Module is always fully present in a single file
|
||||||
|
shards = [file for file in os.listdir(input_dir) if file.endswith(".safetensors")]
|
||||||
|
for shard_file in shards:
|
||||||
|
original_state_dict = load_file(os.path.join(input_dir, shard_file))
|
||||||
|
new_dict = convert_state_dict(original_state_dict, config)
|
||||||
|
full_state_dict.update(new_dict)
|
||||||
|
|
||||||
|
# Load weights into model and resave them
|
||||||
|
with torch.device("meta"):
|
||||||
|
model = Mistral3ForConditionalGeneration(config)
|
||||||
|
model.load_state_dict(full_state_dict, strict=True, assign=True)
|
||||||
|
model.save_pretrained(output_dir)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_and_write_processor(input_dir: str, output_dir: str):
|
||||||
|
"""Convert the tokenizer and save it."""
|
||||||
|
tokenizer_file = os.path.join(input_dir, "tekken.json")
|
||||||
|
tokenizer = convert_tekken_tokenizer(tokenizer_file)
|
||||||
|
tokenizer.add_special_tokens({"pad_token": "<pad>"})
|
||||||
|
chat_template = '{%- if messages[0]["role"] == "system" %}{%- set system_message = messages[0]["content"] %}{%- set loop_messages = messages[1:] %}\n{%- else %}{%- set loop_messages = messages %}{%- endif %}{{- bos_token }}{%- for message in loop_messages %}{%- if (message[\'role\'] == \'user\') != (loop.index0 % 2 == 0) %}{{- raise_exception(\'After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\') }}{%- endif %}{%- if message["role"] == "user" %}{%- if loop.last and system_message is defined %}{{- "[INST]" + system_message + "\n\n" }}{%- else %}{{ "[INST]" }}{%- endif %}{%- endif %}{%- if message["content"] is not string %}{%- for chunk in message["content"] %}{%- if chunk["type"] == "text" %}{%- if "content" in chunk %}{{- chunk["content"] }}{%- elif "text" in chunk %}{{- chunk["text"] }}{%- endif %}{%- elif chunk["type"] == "image" %}{{- "[IMG]" }}{%- else %}{{- raise_exception("Unrecognized content type!") }}{%- endif %}{%- endfor %}{%- else %}{{- message["content"] }}{%- endif %}{%- if message["role"] == "user" %}{{- "[/INST]" }}{%- elif message["role"] == "assistant" %}{{- eos_token}}{%- else %}{{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}{%- endif %}{%- endfor %}'
|
||||||
|
|
||||||
|
config = read_json(os.path.join(input_dir, "params.json"))
|
||||||
|
patch_size = config["vision_encoder"]["patch_size"]
|
||||||
|
spatial_merge_size = config["vision_encoder"]["spatial_merge_size"]
|
||||||
|
max_image_size = config["vision_encoder"]["max_image_size"]
|
||||||
|
image_processor = PixtralImageProcessorFast(patch_size=patch_size, size={"longest_edge": max_image_size})
|
||||||
|
|
||||||
|
processor = PixtralProcessor(
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
image_processor=image_processor,
|
||||||
|
image_token="[IMG]",
|
||||||
|
patch_size=patch_size,
|
||||||
|
chat_template=chat_template,
|
||||||
|
spatial_merge_size=spatial_merge_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Finally save it
|
||||||
|
processor.save_pretrained(output_dir)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"input_dir",
|
||||||
|
help="Location of Mistral weights, which contains tokenizer.model and model folders",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"output_dir",
|
||||||
|
help="Location to write HF model and tokenizer",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max_position_embeddings",
|
||||||
|
type=int,
|
||||||
|
default=131072,
|
||||||
|
help="`max_position_embeddings` field in the config. This needs to be manually passed (not present anywhere otherwise).",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
convert_and_write_model(args.input_dir, args.output_dir, args.max_position_embeddings)
|
||||||
|
convert_and_write_processor(args.input_dir, args.output_dir)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
553
src/transformers/models/mistral3/modeling_mistral3.py
Normal file
553
src/transformers/models/mistral3/modeling_mistral3.py
Normal file
@@ -0,0 +1,553 @@
|
|||||||
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
|
# This file was automatically generated from src/transformers/models/mistral3/modular_mistral3.py.
|
||||||
|
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||||
|
# the file from the modular. If any change should be done, please apply the change to the
|
||||||
|
# modular_mistral3.py file directly. One of our CI enforces this.
|
||||||
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from ...activations import ACT2FN
|
||||||
|
from ...generation import GenerationMixin
|
||||||
|
from ...modeling_outputs import ModelOutput
|
||||||
|
from ...modeling_utils import PreTrainedModel
|
||||||
|
from ...utils import (
|
||||||
|
add_start_docstrings,
|
||||||
|
add_start_docstrings_to_model_forward,
|
||||||
|
is_torchdynamo_compiling,
|
||||||
|
replace_return_docstrings,
|
||||||
|
)
|
||||||
|
from ...utils.deprecation import deprecate_kwarg
|
||||||
|
from ..auto import AutoModel, AutoModelForCausalLM
|
||||||
|
from .configuration_mistral3 import Mistral3Config
|
||||||
|
|
||||||
|
|
||||||
|
_CONFIG_FOR_DOC = "Mistral3Config"
|
||||||
|
|
||||||
|
|
||||||
|
class Mistral3RMSNorm(nn.Module):
|
||||||
|
def __init__(self, hidden_size, eps=1e-6):
|
||||||
|
"""
|
||||||
|
Mistral3RMSNorm is equivalent to T5LayerNorm
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||||
|
self.variance_epsilon = eps
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
input_dtype = hidden_states.dtype
|
||||||
|
hidden_states = hidden_states.to(torch.float32)
|
||||||
|
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||||
|
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||||
|
return self.weight * hidden_states.to(input_dtype)
|
||||||
|
|
||||||
|
def extra_repr(self):
|
||||||
|
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
||||||
|
|
||||||
|
|
||||||
|
class Mistral3PatchMerger(nn.Module):
|
||||||
|
"""
|
||||||
|
Learned merging of spatial_merge_size ** 2 patches
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: Mistral3Config):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
hidden_size = config.vision_config.hidden_size
|
||||||
|
self.spatial_merge_size = config.spatial_merge_size
|
||||||
|
self.patch_size = self.config.vision_config.patch_size
|
||||||
|
self.merging_layer = nn.Linear(hidden_size * self.spatial_merge_size**2, hidden_size, bias=False)
|
||||||
|
|
||||||
|
def forward(self, image_features: torch.Tensor, image_sizes: torch.Tensor) -> torch.Tensor:
|
||||||
|
image_sizes = [
|
||||||
|
(image_size[0] // self.patch_size, image_size[1] // self.patch_size) for image_size in image_sizes
|
||||||
|
]
|
||||||
|
|
||||||
|
tokens_per_image = [h * w for h, w in image_sizes]
|
||||||
|
d = image_features.shape[-1]
|
||||||
|
|
||||||
|
permuted_tensor = []
|
||||||
|
for image_index, image_tokens in enumerate(image_features.split(tokens_per_image)):
|
||||||
|
# Reshape image_tokens into a 2D grid
|
||||||
|
h, w = image_sizes[image_index]
|
||||||
|
image_grid = image_tokens.view(h, w, d).permute(2, 0, 1).unsqueeze(0)
|
||||||
|
grid = torch.nn.functional.unfold(
|
||||||
|
image_grid, kernel_size=self.spatial_merge_size, stride=self.spatial_merge_size
|
||||||
|
)
|
||||||
|
grid = grid.view(d * self.spatial_merge_size**2, -1).t()
|
||||||
|
permuted_tensor.append(grid)
|
||||||
|
|
||||||
|
image_features = torch.cat(permuted_tensor, dim=0)
|
||||||
|
image_features = self.merging_layer(image_features)
|
||||||
|
return image_features
|
||||||
|
|
||||||
|
|
||||||
|
class Mistral3MultiModalProjector(nn.Module):
|
||||||
|
def __init__(self, config: Mistral3Config):
|
||||||
|
super().__init__()
|
||||||
|
self.norm = Mistral3RMSNorm(config.vision_config.hidden_size)
|
||||||
|
self.patch_merger = Mistral3PatchMerger(config)
|
||||||
|
# We have hidden_size * the number of vision feature layers
|
||||||
|
num_feature_layers = 1 if isinstance(config.vision_feature_layer, int) else len(config.vision_feature_layer)
|
||||||
|
self.linear_1 = nn.Linear(
|
||||||
|
config.vision_config.hidden_size * num_feature_layers,
|
||||||
|
config.text_config.hidden_size,
|
||||||
|
bias=config.multimodal_projector_bias,
|
||||||
|
)
|
||||||
|
self.act = ACT2FN[config.projector_hidden_act]
|
||||||
|
self.linear_2 = nn.Linear(
|
||||||
|
config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, image_features: torch.Tensor, image_sizes: torch.Tensor):
|
||||||
|
image_features = self.norm(image_features)
|
||||||
|
image_features = self.patch_merger(image_features, image_sizes)
|
||||||
|
hidden_states = self.linear_1(image_features)
|
||||||
|
hidden_states = self.act(hidden_states)
|
||||||
|
hidden_states = self.linear_2(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Mistral3CausalLMOutputWithPast(ModelOutput):
|
||||||
|
"""
|
||||||
|
Base class for Mistral3 causal language model (or autoregressive) outputs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
||||||
|
Language modeling loss (for next-token prediction).
|
||||||
|
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
|
||||||
|
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||||
|
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||||
|
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||||
|
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
|
||||||
|
|
||||||
|
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
|
||||||
|
`past_key_values` input) to speed up sequential decoding.
|
||||||
|
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||||
|
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
||||||
|
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
|
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
||||||
|
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
||||||
|
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
||||||
|
sequence_length)`.
|
||||||
|
|
||||||
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||||
|
heads.
|
||||||
|
image_hidden_states (`torch.FloatTensor`, *optional*):
|
||||||
|
A `torch.FloatTensor` of size (batch_size, num_images, sequence_length, hidden_size)`.
|
||||||
|
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
|
||||||
|
"""
|
||||||
|
|
||||||
|
loss: Optional[torch.FloatTensor] = None
|
||||||
|
logits: torch.FloatTensor = None
|
||||||
|
past_key_values: Optional[List[torch.FloatTensor]] = None
|
||||||
|
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
|
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
|
image_hidden_states: Optional[torch.FloatTensor] = None
|
||||||
|
|
||||||
|
|
||||||
|
MISTRAL3_START_DOCSTRING = r"""
|
||||||
|
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||||
|
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
||||||
|
etc.)
|
||||||
|
|
||||||
|
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
||||||
|
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
||||||
|
and behavior.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
config ([`Mistral3Config`] or [`Mistral3VisionConfig`]):
|
||||||
|
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
||||||
|
load the weights associated with the model, only the configuration. Check out the
|
||||||
|
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
|
||||||
|
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
|
||||||
|
MISTRAL3_START_DOCSTRING,
|
||||||
|
)
|
||||||
|
class Mistral3PreTrainedModel(PreTrainedModel):
|
||||||
|
config_class = Mistral3Config
|
||||||
|
base_model_prefix = "model"
|
||||||
|
supports_gradient_checkpointing = True
|
||||||
|
_no_split_modules = ["Mistral3VisionAttention"]
|
||||||
|
_skip_keys_device_placement = "past_key_values"
|
||||||
|
_supports_cache_class = True
|
||||||
|
_supports_flash_attn_2 = True
|
||||||
|
_supports_sdpa = True
|
||||||
|
_supports_quantized_cache = True
|
||||||
|
_supports_static_cache = True
|
||||||
|
|
||||||
|
def _init_weights(self, module):
|
||||||
|
# important: this ported version of Mistral3 isn't meant for training from scratch - only
|
||||||
|
# inference and fine-tuning - so the proper init weights code has been removed - the original codebase
|
||||||
|
# https://github.com/haotian-liu/Mistral3/tree/main/mistral3 should serve for that purpose
|
||||||
|
std = (
|
||||||
|
self.config.initializer_range
|
||||||
|
if hasattr(self.config, "initializer_range")
|
||||||
|
else self.config.text_config.initializer_range
|
||||||
|
)
|
||||||
|
|
||||||
|
if hasattr(module, "class_embedding"):
|
||||||
|
module.class_embedding.data.normal_(mean=0.0, std=std)
|
||||||
|
|
||||||
|
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||||
|
module.weight.data.normal_(mean=0.0, std=std)
|
||||||
|
if module.bias is not None:
|
||||||
|
module.bias.data.zero_()
|
||||||
|
elif isinstance(module, nn.Embedding):
|
||||||
|
module.weight.data.normal_(mean=0.0, std=std)
|
||||||
|
if module.padding_idx is not None:
|
||||||
|
module.weight.data[module.padding_idx].zero_()
|
||||||
|
|
||||||
|
|
||||||
|
MISTRAL3_INPUTS_DOCSTRING = r"""
|
||||||
|
Args:
|
||||||
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||||
|
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
||||||
|
it.
|
||||||
|
|
||||||
|
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||||
|
[`PreTrainedTokenizer.__call__`] for details.
|
||||||
|
|
||||||
|
[What are input IDs?](../glossary#input-ids)
|
||||||
|
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)):
|
||||||
|
The tensors corresponding to the input images. Pixel values can be obtained using
|
||||||
|
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details ([]`Mistral3Processor`] uses
|
||||||
|
[`CLIPImageProcessor`] for processing images).
|
||||||
|
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||||
|
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
|
[What are attention masks?](../glossary#attention-mask)
|
||||||
|
|
||||||
|
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||||
|
[`PreTrainedTokenizer.__call__`] for details.
|
||||||
|
|
||||||
|
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
|
||||||
|
`past_key_values`).
|
||||||
|
|
||||||
|
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
||||||
|
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
||||||
|
information on the default strategy.
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the head is **masked**.
|
||||||
|
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
||||||
|
config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
|
||||||
|
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||||
|
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||||
|
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
||||||
|
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
||||||
|
|
||||||
|
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||||||
|
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
||||||
|
|
||||||
|
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
||||||
|
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
||||||
|
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
||||||
|
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
||||||
|
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
||||||
|
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
||||||
|
model's internal embedding lookup matrix.
|
||||||
|
vision_feature_layer (`Union[int, List[int]], *optional*, defaults to -2`):
|
||||||
|
The index of the layer to select the vision feature. If multiple indices are provided,
|
||||||
|
the vision feature of the corresponding indices will be concatenated to form the
|
||||||
|
vision features.
|
||||||
|
vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
|
||||||
|
The feature selection strategy used to select the vision feature from the vision backbone.
|
||||||
|
Can be one of `"default"` or `"full"`.
|
||||||
|
use_cache (`bool`, *optional*):
|
||||||
|
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
||||||
|
`past_key_values`).
|
||||||
|
output_attentions (`bool`, *optional*):
|
||||||
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||||
|
tensors for more detail.
|
||||||
|
output_hidden_states (`bool`, *optional*):
|
||||||
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||||
|
more detail.
|
||||||
|
return_dict (`bool`, *optional*):
|
||||||
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||||
|
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||||
|
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
|
||||||
|
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
|
||||||
|
the complete sequence length.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
|
||||||
|
"""The MISTRAL3 model which consists of a vision backbone and a language model.""",
|
||||||
|
MISTRAL3_START_DOCSTRING,
|
||||||
|
)
|
||||||
|
class Mistral3ForConditionalGeneration(Mistral3PreTrainedModel, GenerationMixin):
|
||||||
|
def __init__(self, config: Mistral3Config):
|
||||||
|
super().__init__(config)
|
||||||
|
self.vision_tower = AutoModel.from_config(config.vision_config)
|
||||||
|
|
||||||
|
self.multi_modal_projector = Mistral3MultiModalProjector(config)
|
||||||
|
self.vocab_size = config.text_config.vocab_size
|
||||||
|
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
|
||||||
|
|
||||||
|
if self.language_model._tied_weights_keys is not None:
|
||||||
|
self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
|
||||||
|
|
||||||
|
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
|
||||||
|
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.language_model.get_input_embeddings()
|
||||||
|
|
||||||
|
def set_input_embeddings(self, value):
|
||||||
|
self.language_model.set_input_embeddings(value)
|
||||||
|
|
||||||
|
def get_output_embeddings(self):
|
||||||
|
return self.language_model.get_output_embeddings()
|
||||||
|
|
||||||
|
def set_output_embeddings(self, new_embeddings):
|
||||||
|
self.language_model.set_output_embeddings(new_embeddings)
|
||||||
|
|
||||||
|
def set_decoder(self, decoder):
|
||||||
|
self.language_model.set_decoder(decoder)
|
||||||
|
|
||||||
|
def get_decoder(self):
|
||||||
|
return self.language_model.get_decoder()
|
||||||
|
|
||||||
|
def get_image_features(
|
||||||
|
self,
|
||||||
|
pixel_values: torch.FloatTensor,
|
||||||
|
vision_feature_layer: Union[int, List[int]],
|
||||||
|
image_sizes: torch.Tensor,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Obtains image last hidden states from the vision tower and apply multimodal projection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`):
|
||||||
|
The tensors corresponding to the input images.
|
||||||
|
vision_feature_layer (`Union[int, List[int]]`):
|
||||||
|
The index of the layer to select the vision feature. If multiple indices are provided,
|
||||||
|
the vision feature of the corresponding indices will be concatenated to form the
|
||||||
|
vision features.
|
||||||
|
image_sizes (`torch.Tensor`):
|
||||||
|
Tensor containing the image sizes as returned by the processor.
|
||||||
|
Returns:
|
||||||
|
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
|
||||||
|
"""
|
||||||
|
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||||
|
# this is not memory efficient at all (output_hidden_states=True) will save all the hidden states.
|
||||||
|
image_outputs = self.vision_tower(pixel_values, image_sizes=image_sizes, output_hidden_states=True, **kwargs)
|
||||||
|
# If we have one vision feature layer, return the corresponding hidden states,
|
||||||
|
# otherwise, select the hidden states of each feature layer and concatenate them
|
||||||
|
if isinstance(vision_feature_layer, int):
|
||||||
|
selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
|
||||||
|
else:
|
||||||
|
hs_pool = [image_outputs.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
|
||||||
|
selected_image_feature = torch.cat(hs_pool, dim=-1)
|
||||||
|
|
||||||
|
image_features = self.multi_modal_projector(selected_image_feature.squeeze(0), image_sizes)
|
||||||
|
return image_features
|
||||||
|
|
||||||
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||||
|
@add_start_docstrings_to_model_forward(MISTRAL3_INPUTS_DOCSTRING)
|
||||||
|
@replace_return_docstrings(output_type=Mistral3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor = None,
|
||||||
|
pixel_values: torch.FloatTensor = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
vision_feature_layer: Optional[Union[int, List[int]]] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||||
|
image_sizes: torch.Tensor = None,
|
||||||
|
**lm_kwargs,
|
||||||
|
) -> Union[Tuple, Mistral3CausalLMOutputWithPast]:
|
||||||
|
r"""
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||||
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||||
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||||
|
|
||||||
|
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||||
|
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||||
|
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||||
|
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||||
|
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||||
|
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||||
|
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from PIL import Image
|
||||||
|
>>> import requests
|
||||||
|
>>> from transformers import AutoProcessor, Mistral3ForConditionalGeneration
|
||||||
|
|
||||||
|
>>> model = Mistral3ForConditionalGeneration.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
|
||||||
|
>>> processor = AutoProcessor.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
|
||||||
|
|
||||||
|
>>> prompt = "<s>[INST][IMG]What is the image?[/INST]"
|
||||||
|
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||||
|
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||||
|
|
||||||
|
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
|
||||||
|
|
||||||
|
>>> # Generate
|
||||||
|
>>> generate_ids = model.generate(**inputs, max_new_tokens=15)
|
||||||
|
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||||
|
"What is the image?The image depicts two cats lying on a pink blanket."
|
||||||
|
```"""
|
||||||
|
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
vision_feature_layer = (
|
||||||
|
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
|
||||||
|
)
|
||||||
|
|
||||||
|
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||||
|
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||||
|
|
||||||
|
if pixel_values is not None and inputs_embeds is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
||||||
|
)
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||||
|
|
||||||
|
if pixel_values is not None:
|
||||||
|
image_features = self.get_image_features(
|
||||||
|
pixel_values=pixel_values,
|
||||||
|
vision_feature_layer=vision_feature_layer,
|
||||||
|
image_sizes=image_sizes,
|
||||||
|
)
|
||||||
|
|
||||||
|
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
|
||||||
|
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||||
|
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||||
|
n_image_tokens = (input_ids == self.config.image_token_index).sum()
|
||||||
|
n_image_features = image_features.shape[0] * image_features.shape[1]
|
||||||
|
raise ValueError(
|
||||||
|
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||||
|
)
|
||||||
|
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||||
|
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
||||||
|
|
||||||
|
outputs = self.language_model(
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
|
logits_to_keep=logits_to_keep,
|
||||||
|
**lm_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
logits = outputs[0]
|
||||||
|
|
||||||
|
loss = None
|
||||||
|
if labels is not None:
|
||||||
|
# Shift so that tokens < n predict n
|
||||||
|
if attention_mask is not None:
|
||||||
|
# we use the input attention mask to shift the logits and labels, because it is 2D.
|
||||||
|
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
|
||||||
|
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
|
||||||
|
shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
|
||||||
|
shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
|
||||||
|
else:
|
||||||
|
shift_logits = logits[..., :-1, :].contiguous()
|
||||||
|
shift_labels = labels[..., 1:].contiguous()
|
||||||
|
# Flatten the tokens
|
||||||
|
loss_fct = nn.CrossEntropyLoss()
|
||||||
|
loss = loss_fct(
|
||||||
|
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
|
||||||
|
)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits,) + outputs[1:]
|
||||||
|
return (loss,) + output if loss is not None else output
|
||||||
|
|
||||||
|
return Mistral3CausalLMOutputWithPast(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
image_hidden_states=image_features if pixel_values is not None else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare_inputs_for_generation(
|
||||||
|
self,
|
||||||
|
input_ids,
|
||||||
|
past_key_values=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
pixel_values=None,
|
||||||
|
attention_mask=None,
|
||||||
|
cache_position=None,
|
||||||
|
logits_to_keep=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
|
||||||
|
|
||||||
|
model_inputs = self.language_model.prepare_inputs_for_generation(
|
||||||
|
input_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
cache_position=cache_position,
|
||||||
|
logits_to_keep=logits_to_keep,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if cache_position[0] == 0:
|
||||||
|
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
|
||||||
|
# Otherwise we need pixel values to be passed to model
|
||||||
|
model_inputs["pixel_values"] = pixel_values
|
||||||
|
|
||||||
|
return model_inputs
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["Mistral3PreTrainedModel", "Mistral3ForConditionalGeneration"]
|
||||||
286
src/transformers/models/mistral3/modular_mistral3.py
Normal file
286
src/transformers/models/mistral3/modular_mistral3.py
Normal file
@@ -0,0 +1,286 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from ...activations import ACT2FN
|
||||||
|
from ...utils import is_torchdynamo_compiling, logging
|
||||||
|
from ..llava.modeling_llava import LlavaCausalLMOutputWithPast, LlavaForConditionalGeneration
|
||||||
|
from ..mistral.modeling_mistral import MistralRMSNorm
|
||||||
|
from .configuration_mistral3 import Mistral3Config
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class Mistral3RMSNorm(MistralRMSNorm):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class Mistral3PatchMerger(nn.Module):
|
||||||
|
"""
|
||||||
|
Learned merging of spatial_merge_size ** 2 patches
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: Mistral3Config):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
hidden_size = config.vision_config.hidden_size
|
||||||
|
self.spatial_merge_size = config.spatial_merge_size
|
||||||
|
self.patch_size = self.config.vision_config.patch_size
|
||||||
|
self.merging_layer = nn.Linear(hidden_size * self.spatial_merge_size**2, hidden_size, bias=False)
|
||||||
|
|
||||||
|
def forward(self, image_features: torch.Tensor, image_sizes: torch.Tensor) -> torch.Tensor:
|
||||||
|
image_sizes = [
|
||||||
|
(image_size[0] // self.patch_size, image_size[1] // self.patch_size) for image_size in image_sizes
|
||||||
|
]
|
||||||
|
|
||||||
|
tokens_per_image = [h * w for h, w in image_sizes]
|
||||||
|
d = image_features.shape[-1]
|
||||||
|
|
||||||
|
permuted_tensor = []
|
||||||
|
for image_index, image_tokens in enumerate(image_features.split(tokens_per_image)):
|
||||||
|
# Reshape image_tokens into a 2D grid
|
||||||
|
h, w = image_sizes[image_index]
|
||||||
|
image_grid = image_tokens.view(h, w, d).permute(2, 0, 1).unsqueeze(0)
|
||||||
|
grid = torch.nn.functional.unfold(
|
||||||
|
image_grid, kernel_size=self.spatial_merge_size, stride=self.spatial_merge_size
|
||||||
|
)
|
||||||
|
grid = grid.view(d * self.spatial_merge_size**2, -1).t()
|
||||||
|
permuted_tensor.append(grid)
|
||||||
|
|
||||||
|
image_features = torch.cat(permuted_tensor, dim=0)
|
||||||
|
image_features = self.merging_layer(image_features)
|
||||||
|
return image_features
|
||||||
|
|
||||||
|
|
||||||
|
class Mistral3MultiModalProjector(nn.Module):
|
||||||
|
def __init__(self, config: Mistral3Config):
|
||||||
|
super().__init__()
|
||||||
|
self.norm = Mistral3RMSNorm(config.vision_config.hidden_size)
|
||||||
|
self.patch_merger = Mistral3PatchMerger(config)
|
||||||
|
# We have hidden_size * the number of vision feature layers
|
||||||
|
num_feature_layers = 1 if isinstance(config.vision_feature_layer, int) else len(config.vision_feature_layer)
|
||||||
|
self.linear_1 = nn.Linear(
|
||||||
|
config.vision_config.hidden_size * num_feature_layers,
|
||||||
|
config.text_config.hidden_size,
|
||||||
|
bias=config.multimodal_projector_bias,
|
||||||
|
)
|
||||||
|
self.act = ACT2FN[config.projector_hidden_act]
|
||||||
|
self.linear_2 = nn.Linear(
|
||||||
|
config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, image_features: torch.Tensor, image_sizes: torch.Tensor):
|
||||||
|
image_features = self.norm(image_features)
|
||||||
|
image_features = self.patch_merger(image_features, image_sizes)
|
||||||
|
hidden_states = self.linear_1(image_features)
|
||||||
|
hidden_states = self.act(hidden_states)
|
||||||
|
hidden_states = self.linear_2(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class Mistral3CausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class Mistral3ForConditionalGeneration(LlavaForConditionalGeneration):
|
||||||
|
def get_image_features(
|
||||||
|
self,
|
||||||
|
pixel_values: torch.FloatTensor,
|
||||||
|
vision_feature_layer: Union[int, List[int]],
|
||||||
|
image_sizes: torch.Tensor,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Obtains image last hidden states from the vision tower and apply multimodal projection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`):
|
||||||
|
The tensors corresponding to the input images.
|
||||||
|
vision_feature_layer (`Union[int, List[int]]`):
|
||||||
|
The index of the layer to select the vision feature. If multiple indices are provided,
|
||||||
|
the vision feature of the corresponding indices will be concatenated to form the
|
||||||
|
vision features.
|
||||||
|
image_sizes (`torch.Tensor`):
|
||||||
|
Tensor containing the image sizes as returned by the processor.
|
||||||
|
Returns:
|
||||||
|
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
|
||||||
|
"""
|
||||||
|
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||||
|
# this is not memory efficient at all (output_hidden_states=True) will save all the hidden states.
|
||||||
|
image_outputs = self.vision_tower(pixel_values, image_sizes=image_sizes, output_hidden_states=True, **kwargs)
|
||||||
|
# If we have one vision feature layer, return the corresponding hidden states,
|
||||||
|
# otherwise, select the hidden states of each feature layer and concatenate them
|
||||||
|
if isinstance(vision_feature_layer, int):
|
||||||
|
selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
|
||||||
|
else:
|
||||||
|
hs_pool = [image_outputs.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
|
||||||
|
selected_image_feature = torch.cat(hs_pool, dim=-1)
|
||||||
|
|
||||||
|
image_features = self.multi_modal_projector(selected_image_feature.squeeze(0), image_sizes)
|
||||||
|
return image_features
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor = None,
|
||||||
|
pixel_values: torch.FloatTensor = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
vision_feature_layer: Optional[Union[int, List[int]]] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||||
|
image_sizes: torch.Tensor = None,
|
||||||
|
**lm_kwargs,
|
||||||
|
) -> Union[Tuple, Mistral3CausalLMOutputWithPast]:
|
||||||
|
r"""
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||||
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||||
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||||
|
|
||||||
|
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||||
|
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||||
|
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||||
|
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||||
|
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||||
|
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||||
|
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from PIL import Image
|
||||||
|
>>> import requests
|
||||||
|
>>> from transformers import AutoProcessor, Mistral3ForConditionalGeneration
|
||||||
|
|
||||||
|
>>> model = Mistral3ForConditionalGeneration.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
|
||||||
|
>>> processor = AutoProcessor.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
|
||||||
|
|
||||||
|
>>> prompt = "<s>[INST][IMG]What is the image?[/INST]"
|
||||||
|
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||||
|
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||||
|
|
||||||
|
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
|
||||||
|
|
||||||
|
>>> # Generate
|
||||||
|
>>> generate_ids = model.generate(**inputs, max_new_tokens=15)
|
||||||
|
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||||
|
"What is the image?The image depicts two cats lying on a pink blanket."
|
||||||
|
```"""
|
||||||
|
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
vision_feature_layer = (
|
||||||
|
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
|
||||||
|
)
|
||||||
|
|
||||||
|
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||||
|
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||||
|
|
||||||
|
if pixel_values is not None and inputs_embeds is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
||||||
|
)
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||||
|
|
||||||
|
if pixel_values is not None:
|
||||||
|
image_features = self.get_image_features(
|
||||||
|
pixel_values=pixel_values,
|
||||||
|
vision_feature_layer=vision_feature_layer,
|
||||||
|
image_sizes=image_sizes,
|
||||||
|
)
|
||||||
|
|
||||||
|
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
|
||||||
|
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||||
|
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||||
|
n_image_tokens = (input_ids == self.config.image_token_index).sum()
|
||||||
|
n_image_features = image_features.shape[0] * image_features.shape[1]
|
||||||
|
raise ValueError(
|
||||||
|
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||||
|
)
|
||||||
|
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||||
|
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
||||||
|
|
||||||
|
outputs = self.language_model(
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
|
logits_to_keep=logits_to_keep,
|
||||||
|
**lm_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
logits = outputs[0]
|
||||||
|
|
||||||
|
loss = None
|
||||||
|
if labels is not None:
|
||||||
|
# Shift so that tokens < n predict n
|
||||||
|
if attention_mask is not None:
|
||||||
|
# we use the input attention mask to shift the logits and labels, because it is 2D.
|
||||||
|
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
|
||||||
|
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
|
||||||
|
shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
|
||||||
|
shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
|
||||||
|
else:
|
||||||
|
shift_logits = logits[..., :-1, :].contiguous()
|
||||||
|
shift_labels = labels[..., 1:].contiguous()
|
||||||
|
# Flatten the tokens
|
||||||
|
loss_fct = nn.CrossEntropyLoss()
|
||||||
|
loss = loss_fct(
|
||||||
|
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
|
||||||
|
)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits,) + outputs[1:]
|
||||||
|
return (loss,) + output if loss is not None else output
|
||||||
|
|
||||||
|
return Mistral3CausalLMOutputWithPast(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
image_hidden_states=image_features if pixel_values is not None else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Mistral3PreTrainedModel", # noqa
|
||||||
|
"Mistral3ForConditionalGeneration",
|
||||||
|
]
|
||||||
@@ -128,8 +128,9 @@ def get_resize_output_image_size(
|
|||||||
|
|
||||||
if ratio > 1:
|
if ratio > 1:
|
||||||
# Orgiginal implementation uses `round` which utilises bankers rounding, which can lead to surprising results
|
# Orgiginal implementation uses `round` which utilises bankers rounding, which can lead to surprising results
|
||||||
height = int(math.ceil(height / ratio))
|
# Here we use floor to ensure the image is always smaller than the given "longest_edge"
|
||||||
width = int(math.ceil(width / ratio))
|
height = int(math.floor(height / ratio))
|
||||||
|
width = int(math.floor(width / ratio))
|
||||||
|
|
||||||
num_height_tokens, num_width_tokens = _num_image_tokens((height, width), (patch_height, patch_width))
|
num_height_tokens, num_width_tokens = _num_image_tokens((height, width), (patch_height, patch_width))
|
||||||
return num_height_tokens * patch_height, num_width_tokens * patch_width
|
return num_height_tokens * patch_height, num_width_tokens * patch_width
|
||||||
|
|||||||
@@ -64,6 +64,8 @@ class PixtralProcessor(ProcessorMixin):
|
|||||||
The tokenizer is a required input.
|
The tokenizer is a required input.
|
||||||
patch_size (`int`, *optional*, defaults to 16):
|
patch_size (`int`, *optional*, defaults to 16):
|
||||||
Patch size from the vision tower.
|
Patch size from the vision tower.
|
||||||
|
spatial_merge_size (`int`, *optional*, defaults to 1):
|
||||||
|
The downsampling factor for the spatial merge operation.
|
||||||
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
|
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
|
||||||
in a chat into a tokenizable string.
|
in a chat into a tokenizable string.
|
||||||
image_token (`str`, *optional*, defaults to `"[IMG]"`):
|
image_token (`str`, *optional*, defaults to `"[IMG]"`):
|
||||||
@@ -78,6 +80,7 @@ class PixtralProcessor(ProcessorMixin):
|
|||||||
valid_kwargs = [
|
valid_kwargs = [
|
||||||
"chat_template",
|
"chat_template",
|
||||||
"patch_size",
|
"patch_size",
|
||||||
|
"spatial_merge_size",
|
||||||
"image_token",
|
"image_token",
|
||||||
"image_break_token",
|
"image_break_token",
|
||||||
"image_end_token",
|
"image_end_token",
|
||||||
@@ -90,6 +93,7 @@ class PixtralProcessor(ProcessorMixin):
|
|||||||
image_processor=None,
|
image_processor=None,
|
||||||
tokenizer=None,
|
tokenizer=None,
|
||||||
patch_size: int = 16,
|
patch_size: int = 16,
|
||||||
|
spatial_merge_size: int = 1,
|
||||||
chat_template=None,
|
chat_template=None,
|
||||||
image_token="[IMG]", # set the default and let users change if they have peculiar special tokens in rare cases
|
image_token="[IMG]", # set the default and let users change if they have peculiar special tokens in rare cases
|
||||||
image_break_token="[IMG_BREAK]",
|
image_break_token="[IMG_BREAK]",
|
||||||
@@ -97,6 +101,7 @@ class PixtralProcessor(ProcessorMixin):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
self.patch_size = patch_size
|
self.patch_size = patch_size
|
||||||
|
self.spatial_merge_size = spatial_merge_size
|
||||||
self.image_token = image_token
|
self.image_token = image_token
|
||||||
self.image_break_token = image_break_token
|
self.image_break_token = image_break_token
|
||||||
self.image_end_token = image_end_token
|
self.image_end_token = image_end_token
|
||||||
@@ -187,8 +192,8 @@ class PixtralProcessor(ProcessorMixin):
|
|||||||
for sample in text:
|
for sample in text:
|
||||||
while self.image_token in sample:
|
while self.image_token in sample:
|
||||||
height, width = next(image_sizes)
|
height, width = next(image_sizes)
|
||||||
num_height_tokens = height // self.patch_size
|
num_height_tokens = height // (self.patch_size * self.spatial_merge_size)
|
||||||
num_width_tokens = width // self.patch_size
|
num_width_tokens = width // (self.patch_size * self.spatial_merge_size)
|
||||||
replace_tokens = [
|
replace_tokens = [
|
||||||
[self.image_token] * num_width_tokens + [self.image_break_token]
|
[self.image_token] * num_width_tokens + [self.image_break_token]
|
||||||
] * num_height_tokens
|
] * num_height_tokens
|
||||||
|
|||||||
@@ -6392,6 +6392,20 @@ class MistralPreTrainedModel(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class Mistral3ForConditionalGeneration(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class Mistral3PreTrainedModel(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class MixtralForCausalLM(metaclass=DummyObject):
|
class MixtralForCausalLM(metaclass=DummyObject):
|
||||||
_backends = ["torch"]
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
|||||||
@@ -125,6 +125,7 @@ VLM_CLASS_NAMES = [
|
|||||||
"qwen2_5_vl",
|
"qwen2_5_vl",
|
||||||
"ayavision",
|
"ayavision",
|
||||||
"gemma3",
|
"gemma3",
|
||||||
|
"mistral3",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
0
tests/models/mistral3/__init__.py
Normal file
0
tests/models/mistral3/__init__.py
Normal file
482
tests/models/mistral3/test_modeling_mistral3.py
Normal file
482
tests/models/mistral3/test_modeling_mistral3.py
Normal file
@@ -0,0 +1,482 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2024 The Qwen team, Alibaba Group and The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Testing suite for the PyTorch GotOcr2 model."""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from transformers import (
|
||||||
|
AutoProcessor,
|
||||||
|
Mistral3Config,
|
||||||
|
is_bitsandbytes_available,
|
||||||
|
is_torch_available,
|
||||||
|
)
|
||||||
|
from transformers.testing_utils import (
|
||||||
|
cleanup,
|
||||||
|
require_bitsandbytes,
|
||||||
|
require_torch,
|
||||||
|
require_torch_gpu,
|
||||||
|
slow,
|
||||||
|
torch_device,
|
||||||
|
)
|
||||||
|
|
||||||
|
from ...generation.test_utils import GenerationTesterMixin
|
||||||
|
from ...test_configuration_common import ConfigTester
|
||||||
|
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
|
||||||
|
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from transformers import (
|
||||||
|
Mistral3ForConditionalGeneration,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if is_bitsandbytes_available():
|
||||||
|
from transformers import BitsAndBytesConfig
|
||||||
|
|
||||||
|
|
||||||
|
class Mistral3VisionText2TextModelTester:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
parent,
|
||||||
|
batch_size=3,
|
||||||
|
seq_length=7,
|
||||||
|
image_seq_length=4,
|
||||||
|
vision_feature_layer=-1,
|
||||||
|
ignore_index=-100,
|
||||||
|
bos_token_id=0,
|
||||||
|
eos_token_id=0,
|
||||||
|
pad_token_id=0,
|
||||||
|
image_token_index=1,
|
||||||
|
num_channels=3,
|
||||||
|
image_size=30,
|
||||||
|
model_type="mistral3",
|
||||||
|
is_training=True,
|
||||||
|
text_config={
|
||||||
|
"model_type": "mistral",
|
||||||
|
"vocab_size": 99,
|
||||||
|
"attention_dropout": 0.0,
|
||||||
|
"hidden_act": "silu",
|
||||||
|
"hidden_size": 32,
|
||||||
|
"initializer_range": 0.02,
|
||||||
|
"intermediate_size": 37,
|
||||||
|
"max_position_embeddings": 512,
|
||||||
|
"num_attention_heads": 4,
|
||||||
|
"num_hidden_layers": 2,
|
||||||
|
"num_key_value_heads": 2,
|
||||||
|
"rms_norm_eps": 1e-05,
|
||||||
|
"rope_theta": 1000000000.0,
|
||||||
|
"sliding_window": None,
|
||||||
|
"bos_token_id": 0,
|
||||||
|
"eos_token_id": 0,
|
||||||
|
"pad_token_id": 0,
|
||||||
|
},
|
||||||
|
vision_config={
|
||||||
|
"model_type": "pixtral",
|
||||||
|
"hidden_size": 32,
|
||||||
|
"num_hidden_layers": 2,
|
||||||
|
"num_attention_heads": 4,
|
||||||
|
"intermediate_size": 37,
|
||||||
|
"image_size": 30,
|
||||||
|
"patch_size": 6,
|
||||||
|
"num_channels": 3,
|
||||||
|
"hidden_act": "gelu",
|
||||||
|
},
|
||||||
|
):
|
||||||
|
self.parent = parent
|
||||||
|
self.ignore_index = ignore_index
|
||||||
|
self.bos_token_id = bos_token_id
|
||||||
|
self.eos_token_id = eos_token_id
|
||||||
|
self.pad_token_id = pad_token_id
|
||||||
|
self.image_token_index = image_token_index
|
||||||
|
self.model_type = model_type
|
||||||
|
self.text_config = text_config
|
||||||
|
self.vision_config = vision_config
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.vision_feature_layer = vision_feature_layer
|
||||||
|
self.is_training = is_training
|
||||||
|
self.image_seq_length = image_seq_length
|
||||||
|
self.num_channels = num_channels
|
||||||
|
self.image_size = image_size
|
||||||
|
self.seq_length = seq_length + self.image_seq_length
|
||||||
|
|
||||||
|
self.num_hidden_layers = text_config["num_hidden_layers"]
|
||||||
|
self.vocab_size = text_config["vocab_size"]
|
||||||
|
self.hidden_size = text_config["hidden_size"]
|
||||||
|
self.num_attention_heads = text_config["num_attention_heads"]
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
return Mistral3Config(
|
||||||
|
text_config=self.text_config,
|
||||||
|
vision_config=self.vision_config,
|
||||||
|
model_type=self.model_type,
|
||||||
|
bos_token_id=self.bos_token_id,
|
||||||
|
eos_token_id=self.eos_token_id,
|
||||||
|
pad_token_id=self.pad_token_id,
|
||||||
|
image_token_index=self.image_token_index,
|
||||||
|
image_seq_length=self.image_seq_length,
|
||||||
|
vision_feature_layer=self.vision_feature_layer,
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare_config_and_inputs(self):
|
||||||
|
config = self.get_config()
|
||||||
|
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||||
|
|
||||||
|
return config, pixel_values
|
||||||
|
|
||||||
|
def prepare_config_and_inputs_for_common(self):
|
||||||
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
|
config, pixel_values = config_and_inputs
|
||||||
|
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||||
|
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
|
||||||
|
image_sizes = torch.tensor(
|
||||||
|
[[self.image_size, self.image_size]] * self.batch_size, dtype=torch.long, device=torch_device
|
||||||
|
)
|
||||||
|
|
||||||
|
# input_ids[:, -1] = self.pad_token_id
|
||||||
|
input_ids[input_ids == self.image_token_index] = self.pad_token_id
|
||||||
|
input_ids[:, : self.image_seq_length] = self.image_token_index
|
||||||
|
|
||||||
|
inputs_dict = {
|
||||||
|
"pixel_values": pixel_values,
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
"image_sizes": image_sizes,
|
||||||
|
}
|
||||||
|
return config, inputs_dict
|
||||||
|
|
||||||
|
def create_and_check_model_fp16_forward(self, config, input_ids, pixel_values, attention_mask):
|
||||||
|
model = Mistral3ForConditionalGeneration(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.half()
|
||||||
|
model.eval()
|
||||||
|
logits = model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
pixel_values=pixel_values.to(torch.bfloat16),
|
||||||
|
return_dict=True,
|
||||||
|
)["logits"]
|
||||||
|
self.parent.assertFalse(torch.isnan(logits).any().item())
|
||||||
|
|
||||||
|
def create_and_check_model_fp16_autocast_forward(self, config, input_ids, pixel_values, attention_mask):
|
||||||
|
config.torch_dtype = torch.float16
|
||||||
|
model = Mistral3ForConditionalGeneration(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
with torch.autocast(device_type="cuda", dtype=torch.float16):
|
||||||
|
logits = model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
pixel_values=pixel_values.to(torch.bfloat16),
|
||||||
|
return_dict=True,
|
||||||
|
)["logits"]
|
||||||
|
self.parent.assertFalse(torch.isnan(logits).any().item())
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
class Mistral3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||||
|
all_model_classes = (Mistral3ForConditionalGeneration,) if is_torch_available() else ()
|
||||||
|
all_generative_model_classes = (Mistral3ForConditionalGeneration,) if is_torch_available() else ()
|
||||||
|
pipeline_model_mapping = (
|
||||||
|
{
|
||||||
|
"image-text-to-text": Mistral3ForConditionalGeneration,
|
||||||
|
}
|
||||||
|
if is_torch_available()
|
||||||
|
else {}
|
||||||
|
)
|
||||||
|
_is_composite = True
|
||||||
|
test_headmasking = False
|
||||||
|
test_pruning = False
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.model_tester = Mistral3VisionText2TextModelTester(self)
|
||||||
|
self.config_tester = ConfigTester(self, config_class=Mistral3Config, has_text_modality=False)
|
||||||
|
|
||||||
|
def test_config(self):
|
||||||
|
# overwritten from `tests/test_configuration_common.py::ConfigTester` after #36077
|
||||||
|
# TODO: avoid overwritten once there is a better fix for #36077
|
||||||
|
def check_config_can_be_init_without_params():
|
||||||
|
config = self.config_tester.config_class()
|
||||||
|
self.config_tester.parent.assertIsNotNone(config)
|
||||||
|
|
||||||
|
self.config_tester.check_config_can_be_init_without_params = check_config_can_be_init_without_params
|
||||||
|
self.config_tester.run_common_tests()
|
||||||
|
|
||||||
|
def test_initialization(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
configs_no_init = _config_zero_init(config)
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config=configs_no_init)
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
if param.requires_grad:
|
||||||
|
self.assertIn(
|
||||||
|
((param.data.mean() * 1e9).round() / 1e9).item(),
|
||||||
|
[0.0, 1.0],
|
||||||
|
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||||
|
)
|
||||||
|
|
||||||
|
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||||
|
def test_inputs_embeds(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
|
||||||
|
input_ids = inputs["input_ids"]
|
||||||
|
del inputs["input_ids"]
|
||||||
|
del inputs["pixel_values"]
|
||||||
|
|
||||||
|
wte = model.get_input_embeddings()
|
||||||
|
inputs["inputs_embeds"] = wte(input_ids)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
model(**inputs)
|
||||||
|
|
||||||
|
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||||
|
# while some other models require pixel_values to be present
|
||||||
|
def test_inputs_embeds_matches_input_ids(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
input_ids = inputs["input_ids"]
|
||||||
|
del inputs["input_ids"]
|
||||||
|
del inputs["pixel_values"]
|
||||||
|
|
||||||
|
inputs_embeds = model.get_input_embeddings()(input_ids)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
out_ids = model(input_ids=input_ids, **inputs)[0]
|
||||||
|
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
|
||||||
|
torch.testing.assert_close(out_embeds, out_ids)
|
||||||
|
|
||||||
|
@unittest.skip(reason="Compile not yet supported because in LLava models")
|
||||||
|
def test_sdpa_can_compile_dynamic(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
|
||||||
|
def test_flash_attn_2_fp32_ln(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("Pixtral does not support attention interfaces.")
|
||||||
|
def test_eager_matches_fa2_generate(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("Pixtral does not support attention interfaces.")
|
||||||
|
def test_eager_matches_sdpa_generate(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("Pixtral does not support attention interfaces.")
|
||||||
|
def test_flash_attn_2_from_config(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("Pixtral does not support attention interfaces.")
|
||||||
|
def test_flash_attn_2_inference_equivalence(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("Pixtral does not support attention interfaces.")
|
||||||
|
def test_flash_attn_2_inference_equivalence_right_padding(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("Pixtral does not support attention interfaces.")
|
||||||
|
def test_sdpa_can_dispatch_on_flash(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_torch_gpu
|
||||||
|
class Mistral3IntegrationTest(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.model_checkpoint = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
cleanup(torch_device, gc_collect=True)
|
||||||
|
|
||||||
|
def test_mistral3_integration_generate_text_only(self):
|
||||||
|
processor = AutoProcessor.from_pretrained(self.model_checkpoint)
|
||||||
|
model = Mistral3ForConditionalGeneration.from_pretrained(
|
||||||
|
self.model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16
|
||||||
|
)
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "Write a haiku"},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
inputs = processor.apply_chat_template(
|
||||||
|
messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
|
||||||
|
).to(torch_device, dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
generate_ids = model.generate(**inputs, max_new_tokens=200, do_sample=False)
|
||||||
|
decoded_output = processor.decode(
|
||||||
|
generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True
|
||||||
|
)
|
||||||
|
expected_output = "Sure, here's a haiku for you:\n\nWhispers of the breeze,\nCherry blossoms softly fall,\nSpring's gentle embrace."
|
||||||
|
self.assertEqual(decoded_output, expected_output)
|
||||||
|
|
||||||
|
def test_mistral3_integration_generate(self):
|
||||||
|
processor = AutoProcessor.from_pretrained(self.model_checkpoint)
|
||||||
|
model = Mistral3ForConditionalGeneration.from_pretrained(
|
||||||
|
self.model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16
|
||||||
|
)
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "image", "url": "http://images.cocodataset.org/val2017/000000039769.jpg"},
|
||||||
|
{"type": "text", "text": "Describe this image"},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
inputs = processor.apply_chat_template(
|
||||||
|
messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
|
||||||
|
).to(torch_device, dtype=torch.bfloat16)
|
||||||
|
with torch.no_grad():
|
||||||
|
generate_ids = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||||
|
decoded_output = processor.decode(
|
||||||
|
generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True
|
||||||
|
)
|
||||||
|
expected_output = "The image depicts two cats lying on a pink blanket. The larger cat, which appears to be an"
|
||||||
|
self.assertEqual(decoded_output, expected_output)
|
||||||
|
|
||||||
|
def test_mistral3_integration_batched_generate(self):
|
||||||
|
processor = AutoProcessor.from_pretrained(self.model_checkpoint)
|
||||||
|
model = Mistral3ForConditionalGeneration.from_pretrained(
|
||||||
|
self.model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16
|
||||||
|
)
|
||||||
|
messages = [
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "image", "url": "https://llava-vl.github.io/static/images/view.jpg"},
|
||||||
|
{"type": "text", "text": "Write a haiku for this image"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
|
||||||
|
{"type": "text", "text": "Describe this image"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
]
|
||||||
|
|
||||||
|
inputs = processor.apply_chat_template(
|
||||||
|
messages, padding=True, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
|
||||||
|
).to(model.device, dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
|
||||||
|
|
||||||
|
# Check first output
|
||||||
|
decoded_output = processor.decode(output[0], skip_special_tokens=True)
|
||||||
|
expected_output = "Write a haiku for this imageSure, here is a haiku inspired by the image:\n\nCalm lake's mirror gleams,\nWhispering pines"
|
||||||
|
self.assertEqual(
|
||||||
|
decoded_output,
|
||||||
|
expected_output,
|
||||||
|
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check second output
|
||||||
|
decoded_output = processor.decode(output[1], skip_special_tokens=True)
|
||||||
|
expected_output = "Describe this imageThe image depicts a vibrant street scene in what appears to be a Chinatown district. The focal point is a traditional Chinese"
|
||||||
|
self.assertEqual(
|
||||||
|
decoded_output,
|
||||||
|
expected_output,
|
||||||
|
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
|
||||||
|
)
|
||||||
|
|
||||||
|
@require_bitsandbytes
|
||||||
|
def test_mistral3_integration_batched_generate_multi_image(self):
|
||||||
|
processor = AutoProcessor.from_pretrained(self.model_checkpoint)
|
||||||
|
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
|
||||||
|
model = Mistral3ForConditionalGeneration.from_pretrained(
|
||||||
|
self.model_checkpoint, quantization_config=quantization_config
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare inputs
|
||||||
|
messages = [
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "image", "url": "https://llava-vl.github.io/static/images/view.jpg"},
|
||||||
|
{"type": "text", "text": "Write a haiku for this image"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"url": "https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "These images depict two different landmarks. Can you identify them?",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
]
|
||||||
|
inputs = processor.apply_chat_template(
|
||||||
|
messages, padding=True, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
|
||||||
|
).to(model.device, dtype=torch.float16)
|
||||||
|
|
||||||
|
output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
|
||||||
|
|
||||||
|
# Check first output
|
||||||
|
decoded_output = processor.decode(output[0], skip_special_tokens=True)
|
||||||
|
expected_output = "Write a haiku for this imageSure, here is a haiku inspired by the image:\n\nCalm lake's wooden path\nSilent forest stands guard\n"
|
||||||
|
self.assertEqual(
|
||||||
|
decoded_output,
|
||||||
|
expected_output,
|
||||||
|
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check second output
|
||||||
|
decoded_output = processor.decode(output[1], skip_special_tokens=True)
|
||||||
|
expected_output = "These images depict two different landmarks. Can you identify them?Certainly! The images depict two iconic landmarks:\n\n1. The first image shows the Statue of Liberty in New York City."
|
||||||
|
self.assertEqual(
|
||||||
|
decoded_output,
|
||||||
|
expected_output,
|
||||||
|
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
|
||||||
|
)
|
||||||
293
tests/models/mistral3/test_processor_mistral3.py
Normal file
293
tests/models/mistral3/test_processor_mistral3.py
Normal file
@@ -0,0 +1,293 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import shutil
|
||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from transformers import PixtralProcessor
|
||||||
|
from transformers.testing_utils import require_vision
|
||||||
|
from transformers.utils import is_torch_available, is_vision_available
|
||||||
|
|
||||||
|
from ...test_processing_common import ProcessorTesterMixin
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
if is_vision_available():
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
@require_vision
|
||||||
|
class Mistral3ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||||
|
"""This tests Pixtral processor with the new `spatial_merge_size` argument in Mistral3."""
|
||||||
|
|
||||||
|
processor_class = PixtralProcessor
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.url_0 = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
||||||
|
cls.image_0 = Image.open(requests.get(cls.url_0, stream=True).raw)
|
||||||
|
cls.url_1 = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||||
|
cls.image_1 = Image.open(requests.get(cls.url_1, stream=True).raw)
|
||||||
|
cls.url_2 = "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg"
|
||||||
|
cls.image_2 = Image.open(requests.get(cls.url_2, stream=True).raw)
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.tmpdirname = tempfile.mkdtemp()
|
||||||
|
processor = PixtralProcessor.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
|
||||||
|
processor.save_pretrained(self.tmpdirname)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
shutil.rmtree(self.tmpdirname)
|
||||||
|
|
||||||
|
def test_chat_template(self):
|
||||||
|
processor = self.processor_class.from_pretrained(self.tmpdirname)
|
||||||
|
expected_prompt = "<s>[INST][IMG]What is shown in this image?[/INST]"
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "image"},
|
||||||
|
{"type": "text", "text": "What is shown in this image?"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
|
||||||
|
self.assertEqual(expected_prompt, formatted_prompt)
|
||||||
|
|
||||||
|
def test_image_token_filling(self):
|
||||||
|
processor = self.processor_class.from_pretrained(self.tmpdirname)
|
||||||
|
# Important to check with non square image
|
||||||
|
image = torch.randint(0, 2, (3, 500, 316))
|
||||||
|
expected_image_tokens = 198
|
||||||
|
image_token_index = 10
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "image"},
|
||||||
|
{"type": "text", "text": "What is shown in this image?"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
inputs = processor(
|
||||||
|
text=[processor.apply_chat_template(messages)],
|
||||||
|
images=[image],
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
image_tokens = (inputs["input_ids"] == image_token_index).sum().item()
|
||||||
|
self.assertEqual(expected_image_tokens, image_tokens)
|
||||||
|
|
||||||
|
def test_processor_with_single_image(self):
|
||||||
|
processor = self.processor_class.from_pretrained(self.tmpdirname)
|
||||||
|
prompt_string = "USER: [IMG]\nWhat's the content of the image? ASSISTANT:"
|
||||||
|
|
||||||
|
# Make small for checking image token expansion
|
||||||
|
processor.image_processor.size = {"longest_edge": 30}
|
||||||
|
processor.patch_size = 6
|
||||||
|
|
||||||
|
# Test passing in an image
|
||||||
|
inputs_image = processor(text=prompt_string, images=self.image_0, return_tensors="pt")
|
||||||
|
self.assertIn("input_ids", inputs_image)
|
||||||
|
self.assertTrue(len(inputs_image["input_ids"]) == 1)
|
||||||
|
self.assertIsInstance(inputs_image["input_ids"], torch.Tensor)
|
||||||
|
self.assertIsInstance(inputs_image["pixel_values"], torch.Tensor)
|
||||||
|
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([1, 3, 24, 30]))
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
input_ids = inputs_image["input_ids"]
|
||||||
|
self.assertEqual(
|
||||||
|
input_ids[0].tolist(),
|
||||||
|
# Equivalent to "USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the content of the image? ASSISTANT:"
|
||||||
|
[1, 21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 4701, 1307, 1278, 3937, 1063, 1349, 4290, 16002, 41150, 1058]
|
||||||
|
)
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
# Test passing in a url
|
||||||
|
inputs_url = processor(text=prompt_string, images=self.url_0, return_tensors="pt")
|
||||||
|
self.assertIn("input_ids", inputs_url)
|
||||||
|
self.assertTrue(len(inputs_url["input_ids"]) == 1)
|
||||||
|
self.assertIsInstance(inputs_url["input_ids"], torch.Tensor)
|
||||||
|
self.assertIsInstance(inputs_image["pixel_values"], torch.Tensor)
|
||||||
|
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([1, 3, 24, 30]))
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
input_ids = inputs_url["input_ids"]
|
||||||
|
self.assertEqual(
|
||||||
|
input_ids[0].tolist(),
|
||||||
|
# Equivalent to "USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the content of the image? ASSISTANT:"
|
||||||
|
[1, 21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 4701, 1307, 1278, 3937, 1063, 1349, 4290, 16002, 41150, 1058]
|
||||||
|
)
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
# Test passing inputs as a single list
|
||||||
|
inputs_image = processor(text=prompt_string, images=[self.image_0], return_tensors="pt")
|
||||||
|
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([1, 3, 24, 30]))
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
self.assertEqual(
|
||||||
|
inputs_image["input_ids"][0].tolist(),
|
||||||
|
[1, 21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 4701, 1307, 1278, 3937, 1063, 1349, 4290, 16002, 41150, 1058]
|
||||||
|
)
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
# Test as nested single list
|
||||||
|
inputs_image = processor(text=prompt_string, images=[[self.image_0]], return_tensors="pt")
|
||||||
|
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([1, 3, 24, 30]))
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
self.assertEqual(
|
||||||
|
inputs_image["input_ids"][0].tolist(),
|
||||||
|
[1, 21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 4701, 1307, 1278, 3937, 1063, 1349, 4290, 16002, 41150, 1058]
|
||||||
|
)
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
def test_processor_with_multiple_images_single_list(self):
|
||||||
|
processor = self.processor_class.from_pretrained(self.tmpdirname)
|
||||||
|
prompt_string = "USER: [IMG][IMG]\nWhat's the difference between these two images? ASSISTANT:"
|
||||||
|
|
||||||
|
# Make small for checking image token expansion
|
||||||
|
processor.image_processor.size = {"longest_edge": 30}
|
||||||
|
processor.patch_size = 6
|
||||||
|
|
||||||
|
# Test passing in an image
|
||||||
|
inputs_image = processor(text=prompt_string, images=[self.image_0, self.image_1], return_tensors="pt")
|
||||||
|
self.assertIn("input_ids", inputs_image)
|
||||||
|
self.assertTrue(len(inputs_image["input_ids"]) == 1)
|
||||||
|
self.assertIsInstance(inputs_image["input_ids"], torch.Tensor)
|
||||||
|
self.assertIsInstance(inputs_image["pixel_values"], torch.Tensor)
|
||||||
|
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([2, 3, 24, 30]))
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
input_ids = inputs_image["input_ids"]
|
||||||
|
self.assertEqual(
|
||||||
|
input_ids[0].tolist(),
|
||||||
|
# Equivalent to ["USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END][IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the difference between these two images? ASSISTANT:"]
|
||||||
|
[1, 21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058]
|
||||||
|
)
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
# Test passing in a url
|
||||||
|
inputs_url = processor(text=prompt_string, images=[self.url_0, self.url_1], return_tensors="pt")
|
||||||
|
self.assertIn("input_ids", inputs_url)
|
||||||
|
self.assertTrue(len(inputs_url["input_ids"]) == 1)
|
||||||
|
self.assertIsInstance(inputs_url["input_ids"], torch.Tensor)
|
||||||
|
self.assertIsInstance(inputs_image["pixel_values"], torch.Tensor)
|
||||||
|
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([2, 3, 24, 30]))
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
input_ids = inputs_url["input_ids"]
|
||||||
|
self.assertEqual(
|
||||||
|
input_ids[0].tolist(),
|
||||||
|
# Equivalent to ["USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END][IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the difference between these two images? ASSISTANT:"]
|
||||||
|
[1, 21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058]
|
||||||
|
)
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
# Test passing in as a nested list
|
||||||
|
inputs_url = processor(text=prompt_string, images=[[self.image_0, self.image_1]], return_tensors="pt")
|
||||||
|
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([2, 3, 24, 30]))
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
self.assertEqual(
|
||||||
|
inputs_url["input_ids"][0].tolist(),
|
||||||
|
[1, 21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058]
|
||||||
|
)
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
def test_processor_with_multiple_images_multiple_lists(self):
|
||||||
|
processor = self.processor_class.from_pretrained(self.tmpdirname)
|
||||||
|
prompt_string = [
|
||||||
|
"USER: [IMG][IMG]\nWhat's the difference between these two images? ASSISTANT:",
|
||||||
|
"USER: [IMG]\nWhat's the content of the image? ASSISTANT:",
|
||||||
|
]
|
||||||
|
processor.tokenizer.pad_token = "</s>"
|
||||||
|
image_inputs = [[self.image_0, self.image_1], [self.image_2]]
|
||||||
|
|
||||||
|
# Make small for checking image token expansion
|
||||||
|
processor.image_processor.size = {"longest_edge": 30}
|
||||||
|
processor.patch_size = 6
|
||||||
|
|
||||||
|
# Test passing in an image
|
||||||
|
inputs_image = processor(text=prompt_string, images=image_inputs, return_tensors="pt", padding=True)
|
||||||
|
self.assertIn("input_ids", inputs_image)
|
||||||
|
self.assertTrue(len(inputs_image["input_ids"]) == 2)
|
||||||
|
self.assertIsInstance(inputs_image["input_ids"], torch.Tensor)
|
||||||
|
self.assertIsInstance(inputs_image["pixel_values"], torch.Tensor)
|
||||||
|
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([3, 3, 30, 30]))
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
input_ids = inputs_image["input_ids"]
|
||||||
|
self.assertEqual(
|
||||||
|
input_ids[0].tolist(),
|
||||||
|
# Equivalent to ["USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END][IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the difference between these two images? ASSISTANT:"]
|
||||||
|
[1, 21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058]
|
||||||
|
)
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
# Test passing in a url
|
||||||
|
inputs_url = processor(text=prompt_string, images=image_inputs, return_tensors="pt", padding=True)
|
||||||
|
self.assertIn("input_ids", inputs_url)
|
||||||
|
self.assertTrue(len(inputs_url["input_ids"]) == 2)
|
||||||
|
self.assertIsInstance(inputs_url["input_ids"], torch.Tensor)
|
||||||
|
self.assertIsInstance(inputs_image["pixel_values"], torch.Tensor)
|
||||||
|
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([3, 3, 30, 30]))
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
input_ids = inputs_url["input_ids"]
|
||||||
|
self.assertEqual(
|
||||||
|
input_ids[0].tolist(),
|
||||||
|
# Equivalent to ["USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END][IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the difference between these two images? ASSISTANT:"]
|
||||||
|
[1, 21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058]
|
||||||
|
)
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
# Test passing as a single flat list
|
||||||
|
inputs_image = processor(
|
||||||
|
text=prompt_string, images=[self.image_0, self.image_1, self.image_2], return_tensors="pt", padding=True
|
||||||
|
)
|
||||||
|
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([3, 3, 30, 30]))
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
self.assertEqual(
|
||||||
|
inputs_image["input_ids"][0].tolist(),
|
||||||
|
[1, 21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058]
|
||||||
|
)
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
def test_processor_returns_full_length_batches(self):
|
||||||
|
# to avoid https://github.com/huggingface/transformers/issues/34204
|
||||||
|
processor = self.processor_class.from_pretrained(self.tmpdirname)
|
||||||
|
prompt_string = [
|
||||||
|
"USER: [IMG]\nWhat's the content of the image? ASSISTANT:",
|
||||||
|
] * 5
|
||||||
|
processor.tokenizer.pad_token = "</s>"
|
||||||
|
image_inputs = [[self.image_0]] * 5
|
||||||
|
|
||||||
|
# Make small for checking image token expansion
|
||||||
|
processor.image_processor.size = {"longest_edge": 30}
|
||||||
|
processor.patch_size = 6
|
||||||
|
|
||||||
|
# Test passing in an image
|
||||||
|
inputs_image = processor(text=prompt_string, images=image_inputs, return_tensors="pt", padding=True)
|
||||||
|
self.assertIn("input_ids", inputs_image)
|
||||||
|
self.assertTrue(len(inputs_image["input_ids"]) == 5)
|
||||||
@@ -109,8 +109,8 @@ class PixtralImageProcessingTester:
|
|||||||
|
|
||||||
ratio = max(height / max_height, width / max_width)
|
ratio = max(height / max_height, width / max_width)
|
||||||
if ratio > 1:
|
if ratio > 1:
|
||||||
height = int(np.ceil(height / ratio))
|
height = int(np.floor(height / ratio))
|
||||||
width = int(np.ceil(width / ratio))
|
width = int(np.floor(width / ratio))
|
||||||
|
|
||||||
patch_height, patch_width = self.patch_size["height"], self.patch_size["width"]
|
patch_height, patch_width = self.patch_size["height"], self.patch_size["width"]
|
||||||
num_height_tokens = (height - 1) // patch_height + 1
|
num_height_tokens = (height - 1) // patch_height + 1
|
||||||
|
|||||||
Reference in New Issue
Block a user