Add support for Pixtral (#33449)
* initial commit
* gloups
* updates
* work
* weights match
* nits
* nits
* updates to support the tokenizer :)
* updates
* Pixtral processor (#33454)
* rough outline
* Add in image break and end tokens
* Fix
* Udo some formatting changes
* Set patch_size default
* Fix
* Fix token expansion
* nit in conversion script
* Fix image token list creation
* done
* add expected results
* Process list of list of images (#33465)
* updates
* working image and processor
* this is the expected format
* some fixes
* push current updated
* working mult images!
* add a small integration test
* Uodate configuration docstring
* Formatting
* Config docstring fix
* simplify model test
* fixup modeling and etests
* Return BatchMixFeature in image processor
* fix some copies
* update
* nits
* Update model docstring
* Apply suggestions from code review
* Fix up
* updates
* revert modeling changes
* update
* update
* fix load safe
* addd liscence
* update
* use pixel_values as required by the model
* skip some tests and refactor
* Add pixtral image processing tests (#33476)
* Image processing tests
* Add processing tests
* woops
* defaults reflect pixtral image processor
* fixup post merge
* images -> pixel values
* oups sorry Mr docbuilder
* isort
* fix
* fix processor tests
* small fixes
* nit
* update
* last nits
* oups this was really breaking!
* nits
* is composition needs to be true

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
@@ -862,6 +862,8 @@
         title: Perceiver
       - local: model_doc/pix2struct
         title: Pix2Struct
+      - local: model_doc/pixtral
+        title: Pixtral
       - local: model_doc/sam
         title: Segment Anything
       - local: model_doc/siglip

@@ -253,6 +253,7 @@ Flax), PyTorch, and/or TensorFlow.
 | [Phi3](model_doc/phi3) | ✅ | ❌ | ❌ |
 | [PhoBERT](model_doc/phobert) | ✅ | ✅ | ✅ |
 | [Pix2Struct](model_doc/pix2struct) | ✅ | ❌ | ❌ |
+| [Pixtral](model_doc/pixtral) | ❌ | ❌ | ❌ |
 | [PLBart](model_doc/plbart) | ✅ | ❌ | ❌ |
 | [PoolFormer](model_doc/poolformer) | ✅ | ❌ | ❌ |
 | [Pop2Piano](model_doc/pop2piano) | ✅ | ❌ | ❌ |

98
docs/source/en/model_doc/pixtral.md
Normal file
@@ -0,0 +1,98 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Pixtral

## Overview

The Pixtral model was released by the Mistral AI team on [vLLM](https://github.com/vllm-project/vllm/pull/8377), where a version of the code can be found.


Tips:

- Pixtral is a multimodal model. Its main contributions are the 2D RoPE applied to the images and support for arbitrary image sizes (the images are neither padded together nor resized).
- This model follows the `Llava` family, meaning image embeddings are placed where the `[IMG]` token placeholders are.
- The format for one or multiple prompts is the following:
```
"<s>[INST][IMG]\nWhat are the things I should be cautious about when I visit this place?[/INST]"
```
Then, the processor will replace each `[IMG]` token with a number of `[IMG]` tokens that depends on the height and the width of the image. Each *row* of the image is separated by an `[IMG_BREAK]` token, and each image is separated by an `[IMG_END]` token.
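
The token expansion described in the tips can be sketched as follows. This is a minimal illustration, not the processor's actual code: the function name `expand_image_tokens` is hypothetical, the patch size of 16 is taken from the `PixtralVisionConfig` default, and the choices to floor-divide the image size by the patch size and to let `[IMG_END]` take the place of the last row's `[IMG_BREAK]` are assumptions.

```python
def expand_image_tokens(
    height,
    width,
    patch_size=16,
    image_token="[IMG]",
    break_token="[IMG_BREAK]",
    end_token="[IMG_END]",
):
    """Return the placeholder tokens for one image of the given pixel size.

    Assumes the image is at least one patch high and one patch wide.
    """
    rows = height // patch_size  # number of patch rows (assumed floor division)
    cols = width // patch_size  # number of patches per row
    tokens = []
    for _ in range(rows):
        tokens.extend([image_token] * cols)
        tokens.append(break_token)  # marks the end of a row of patches
    # Assumption: the break after the final row becomes the end-of-image marker.
    tokens[-1] = end_token
    return tokens


# A 32x48 image with 16-pixel patches gives 2 rows of 3 patches each.
tokens = expand_image_tokens(height=32, width=48)
print("".join(tokens))
```

Under these assumptions, an image of `height x width` pixels contributes `(height // 16) * (width // 16)` `[IMG]` tokens plus one separator per row, which is why the expanded prompt length depends on the image size.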

This model was contributed by [amyeroberts](https://huggingface.co/amyeroberts) and [ArthurZ](https://huggingface.co/ArthurZ).

Here is an example of how to run it:

```python
from transformers import LlavaForConditionalGeneration, AutoProcessor
from PIL import Image

model_id = "hf-internal-testing/pixtral-12b"
model = LlavaForConditionalGeneration.from_pretrained(model_id).to("cuda")
processor = AutoProcessor.from_pretrained(model_id)

IMG_URLS = [
    "https://picsum.photos/id/237/400/300",
    "https://picsum.photos/id/231/200/300",
    "https://picsum.photos/id/27/500/500",
    "https://picsum.photos/id/17/150/600",
]
PROMPT = "<s>[INST]Describe the images.\n[IMG][IMG][IMG][IMG][/INST]"

inputs = processor(text=PROMPT, images=IMG_URLS, return_tensors="pt").to("cuda")
generate_ids = model.generate(**inputs, max_new_tokens=500)
output = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

EXPECTED_GENERATION = """
Describe the images.
Sure, let's break down each image description:

1. **Image 1:**
   - **Description:** A black dog with a glossy coat is sitting on a wooden floor. The dog has a focused expression and is looking directly at the camera.
   - **Details:** The wooden floor has a rustic appearance with visible wood grain patterns. The dog's eyes are a striking color, possibly brown or amber, which contrasts with its black fur.

2. **Image 2:**
   - **Description:** A scenic view of a mountainous landscape with a winding road cutting through it. The road is surrounded by lush green vegetation and leads to a distant valley.
   - **Details:** The mountains are rugged with steep slopes, and the sky is clear, indicating good weather. The winding road adds a sense of depth and perspective to the image.

3. **Image 3:**
   - **Description:** A beach scene with waves crashing against the shore. There are several people in the water and on the beach, enjoying the waves and the sunset.
   - **Details:** The waves are powerful, creating a dynamic and lively atmosphere. The sky is painted with hues of orange and pink from the setting sun, adding a warm glow to the scene.

4. **Image 4:**
   - **Description:** A garden path leading to a large tree with a bench underneath it. The path is bordered by well-maintained grass and flowers.
   - **Details:** The path is made of small stones or gravel, and the tree provides a shaded area with the bench invitingly placed beneath it. The surrounding area is lush and green, suggesting a well-kept garden.

Each image captures a different scene, from a close-up of a dog to expansive natural landscapes, showcasing various elements of nature and human interaction with it.
"""

```
## PixtralVisionConfig

[[autodoc]] PixtralVisionConfig

## PixtralModel

[[autodoc]] PixtralModel
    - forward

## PixtralImageProcessor

[[autodoc]] PixtralImageProcessor
    - preprocess

## PixtralProcessor

[[autodoc]] PixtralProcessor

@@ -649,6 +649,7 @@ _import_structure = {
         "Pix2StructTextConfig",
         "Pix2StructVisionConfig",
     ],
+    "models.pixtral": ["PixtralProcessor", "PixtralVisionConfig"],
     "models.plbart": ["PLBartConfig"],
     "models.poolformer": ["PoolFormerConfig"],
     "models.pop2piano": ["Pop2PianoConfig"],
@@ -1199,6 +1200,7 @@ else:
     _import_structure["models.owlvit"].extend(["OwlViTFeatureExtractor", "OwlViTImageProcessor"])
     _import_structure["models.perceiver"].extend(["PerceiverFeatureExtractor", "PerceiverImageProcessor"])
     _import_structure["models.pix2struct"].extend(["Pix2StructImageProcessor"])
+    _import_structure["models.pixtral"].append("PixtralImageProcessor")
     _import_structure["models.poolformer"].extend(["PoolFormerFeatureExtractor", "PoolFormerImageProcessor"])
     _import_structure["models.pvt"].extend(["PvtImageProcessor"])
     _import_structure["models.qwen2_vl"].extend(["Qwen2VLImageProcessor"])
@@ -1359,7 +1361,6 @@ else:
             "AlignVisionModel",
         ]
     )
-
     _import_structure["models.altclip"].extend(
         [
             "AltCLIPModel",
@@ -2977,6 +2978,7 @@ else:
             "Pix2StructVisionModel",
         ]
     )
+    _import_structure["models.pixtral"].extend(["PixtralModel", "PixtralPreTrainedModel"])
     _import_structure["models.plbart"].extend(
         [
             "PLBartForCausalLM",
@@ -5434,6 +5436,10 @@ if TYPE_CHECKING:
         Pix2StructTextConfig,
         Pix2StructVisionConfig,
     )
+    from .models.pixtral import (
+        PixtralProcessor,
+        PixtralVisionConfig,
+    )
     from .models.plbart import PLBartConfig
     from .models.poolformer import (
         PoolFormerConfig,
@@ -6009,6 +6015,7 @@ if TYPE_CHECKING:
         from .models.owlvit import OwlViTFeatureExtractor, OwlViTImageProcessor
         from .models.perceiver import PerceiverFeatureExtractor, PerceiverImageProcessor
         from .models.pix2struct import Pix2StructImageProcessor
+        from .models.pixtral import PixtralImageProcessor
         from .models.poolformer import (
             PoolFormerFeatureExtractor,
             PoolFormerImageProcessor,
@@ -7448,6 +7455,10 @@ if TYPE_CHECKING:
             Pix2StructTextModel,
             Pix2StructVisionModel,
         )
+        from .models.pixtral import (
+            PixtralModel,
+            PixtralPreTrainedModel,
+        )
         from .models.plbart import (
             PLBartForCausalLM,
             PLBartForConditionalGeneration,

@@ -187,6 +187,7 @@ from . import (
     phi3,
     phobert,
     pix2struct,
+    pixtral,
     plbart,
     poolformer,
     pop2piano,

@@ -205,6 +205,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
         ("phi", "PhiConfig"),
         ("phi3", "Phi3Config"),
         ("pix2struct", "Pix2StructConfig"),
+        ("pixtral", "PixtralVisionConfig"),
         ("plbart", "PLBartConfig"),
         ("poolformer", "PoolFormerConfig"),
         ("pop2piano", "Pop2PianoConfig"),
@@ -509,6 +510,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
         ("phi3", "Phi3"),
         ("phobert", "PhoBERT"),
         ("pix2struct", "Pix2Struct"),
+        ("pixtral", "Pixtral"),
         ("plbart", "PLBart"),
         ("poolformer", "PoolFormer"),
         ("pop2piano", "Pop2Piano"),

@@ -114,6 +114,7 @@ else:
             ("owlvit", ("OwlViTImageProcessor",)),
             ("perceiver", ("PerceiverImageProcessor",)),
             ("pix2struct", ("Pix2StructImageProcessor",)),
+            ("pixtral", ("PixtralImageProcessor",)),
             ("poolformer", ("PoolFormerImageProcessor",)),
             ("pvt", ("PvtImageProcessor",)),
             ("pvt_v2", ("PvtImageProcessor",)),

@@ -193,6 +193,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
         ("persimmon", "PersimmonModel"),
         ("phi", "PhiModel"),
         ("phi3", "Phi3Model"),
+        ("pixtral", "PixtralModel"),
         ("plbart", "PLBartModel"),
         ("poolformer", "PoolFormerModel"),
         ("prophetnet", "ProphetNetModel"),

@@ -82,6 +82,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
         ("owlvit", "OwlViTProcessor"),
         ("paligemma", "PaliGemmaProcessor"),
         ("pix2struct", "Pix2StructProcessor"),
+        ("pixtral", "PixtralProcessor"),
         ("pop2piano", "Pop2PianoProcessor"),
         ("qwen2_audio", "Qwen2AudioProcessor"),
         ("qwen2_vl", "Qwen2VLProcessor"),

@@ -385,6 +385,7 @@ else:
             ("phi3", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
             ("phobert", ("PhobertTokenizer", None)),
             ("pix2struct", ("T5Tokenizer", "T5TokenizerFast" if is_tokenizers_available() else None)),
+            ("pixtral", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
             ("plbart", ("PLBartTokenizer" if is_sentencepiece_available() else None, None)),
             ("prophetnet", ("ProphetNetTokenizer", None)),
             ("qdqbert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),

@@ -73,7 +73,7 @@ class LlavaConfig(PretrainedConfig):
     ```"""
 
     model_type = "llava"
-    is_composition = False
+    is_composition = True
 
     def __init__(
         self,

70
src/transformers/models/pixtral/__init__.py
Normal file
@@ -0,0 +1,70 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available


_import_structure = {
    "configuration_pixtral": ["PixtralVisionConfig"],
    "processing_pixtral": ["PixtralProcessor"],
}


try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_pixtral"] = [
        "PixtralModel",
        "PixtralPreTrainedModel",
    ]

try:
    if not is_vision_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["image_processing_pixtral"] = ["PixtralImageProcessor"]


if TYPE_CHECKING:
    # PixtralProcessor lives in processing_pixtral, matching _import_structure above.
    from .configuration_pixtral import PixtralVisionConfig
    from .processing_pixtral import PixtralProcessor

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_pixtral import (
            PixtralModel,
            PixtralPreTrainedModel,
        )

    try:
        if not is_vision_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .image_processing_pixtral import PixtralImageProcessor

else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)

103
src/transformers/models/pixtral/configuration_pixtral.py
Normal file
@@ -0,0 +1,103 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pixtral model configuration"""

from ...configuration_utils import PretrainedConfig
from ...utils import logging


logger = logging.get_logger(__name__)


class PixtralVisionConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`PixtralModel`]. It is used to instantiate a
    Pixtral model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the Pixtral-9B.

    e.g. [pixtral-hf/pixtral-9b](https://huggingface.co/pixtral-hf/pixtral-9b)

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        hidden_size (`int`, *optional*, defaults to 1024):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 4096):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 24):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads in the Transformer encoder.
        num_channels (`int`, *optional*, defaults to 3):
            Number of input channels in the input images.
        image_size (`int`, *optional*, defaults to 1024):
            Max dimension of the input images.
        patch_size (`int`, *optional*, defaults to 16):
            Size of the image patches.
        hidden_act (`str`, *optional*, defaults to `"gelu"`):
            Activation function used in the hidden layers.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            Dropout probability for the attention layers.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie the word embeddings with the input embeddings.

    Example:

    ```python
    >>> from transformers import PixtralModel, PixtralVisionConfig

    >>> # Initializing a Pixtral 12B style configuration
    >>> config = PixtralVisionConfig()

    >>> # Initializing a model from the pixtral 12B style configuration
    >>> model = PixtralModel(config)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "pixtral"

    def __init__(
        self,
        hidden_size=1024,
        intermediate_size=4096,
        num_hidden_layers=24,
        num_attention_heads=16,
        num_channels=3,
        image_size=1024,
        patch_size=16,
        hidden_act="gelu",
        attention_dropout=0.0,
        rope_theta=10000.0,
        tie_word_embeddings=False,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.image_size = image_size
        self.attention_dropout = attention_dropout
        self.hidden_act = hidden_act
        self.rope_theta = rope_theta
        self.tie_word_embeddings = tie_word_embeddings
        self.head_dim = hidden_size // num_attention_heads

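
As a quick sanity check on the defaults above, the derived sizes can be worked out by hand. The sketch below only repeats the arithmetic from `__init__` (`head_dim`) and adds one illustrative quantity, the largest patch grid, under the assumption that `image_size` bounds both the height and the width. The helper name `derived_sizes` is hypothetical and not part of the library.

```python
def derived_sizes(hidden_size=1024, num_attention_heads=16, image_size=1024, patch_size=16):
    """Compute sizes implied by the configuration defaults (illustrative only)."""
    # Same formula as in PixtralVisionConfig.__init__.
    head_dim = hidden_size // num_attention_heads
    # Assumption: image_size is the maximum side length, so this is the
    # largest number of patches along one axis.
    max_patches_per_side = image_size // patch_size
    return {
        "head_dim": head_dim,
        "max_patches_per_side": max_patches_per_side,
        "max_patches": max_patches_per_side**2,
    }


print(derived_sizes())
```

With the defaults this gives a head dimension of 64 and, under the stated assumption, at most a 64 x 64 grid of patches per image.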
285
src/transformers/models/pixtral/convert_pixtral_weights_to_hf.py
Normal file
@@ -0,0 +1,285 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse

import regex as re
import torch
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from safetensors.torch import load_file as safe_load_file
from tokenizers import Regex, Tokenizer, decoders, pre_tokenizers, processors
from tokenizers.models import BPE

from transformers import (
    LlavaConfig,
    LlavaForConditionalGeneration,
    MistralConfig,
    PixtralImageProcessor,
    PixtralProcessor,
    PixtralVisionConfig,
    PreTrainedTokenizerFast,
)
from transformers.convert_slow_tokenizer import bytes_to_unicode


"""
# Here is how to get the original tokens!
model_name = "mistralai/Pixtral-12B-2409"
tok = MistralTokenizer.from_model(model_name)

from mistral_common.protocol.instruct.request import ChatCompletionRequest, UserMessage, ImageChunk, TextChunk

EXPECTED_TOKENS = tok.encode_chat_completion(
    ChatCompletionRequest(
        messages=[
            UserMessage(
                content=[
                    TextChunk(text="Describe the images"),
                ] + [ImageChunk(image=img) for img in IMG_URLS]
            )
        ],
        model="pixtral",
    )
)
assert tokenizer.decode(inputs["input_ids"][0]) == EXPECTED_TOKENS
"""

OLD_KEY_TO_NEW_KEY_MAPPING = {
    # Layer Normalization Weights
    r"vision_encoder.transformer.layers.(\d+).input_layernorm.weight": r"vision_tower.transformer.layers.\1.attention_norm.weight",
    r"vision_encoder.transformer.layers.(\d+).ffn_norm.weight": r"vision_tower.transformer.layers.\1.ffn_norm.weight",
    # Self Attention Projections
    r"vision_encoder.transformer.layers.(\d+).attention.wq.weight": r"vision_tower.transformer.layers.\1.attention.q_proj.weight",
    r"vision_encoder.transformer.layers.(\d+).attention.wk.weight": r"vision_tower.transformer.layers.\1.attention.k_proj.weight",
    r"vision_encoder.transformer.layers.(\d+).attention.wv.weight": r"vision_tower.transformer.layers.\1.attention.v_proj.weight",
    r"vision_encoder.transformer.layers.(\d+).attention.wo.weight": r"vision_tower.transformer.layers.\1.attention.o_proj.weight",
    # MLP Projections
    r"vision_encoder.transformer.layers.(\d+).feed_forward.w1.weight": r"vision_tower.transformer.layers.\1.feed_forward.gate_proj.weight",
    r"vision_encoder.transformer.layers.(\d+).feed_forward.w2.weight": r"vision_tower.transformer.layers.\1.feed_forward.down_proj.weight",
    r"vision_encoder.transformer.layers.(\d+).feed_forward.w3.weight": r"vision_tower.transformer.layers.\1.feed_forward.up_proj.weight",
    # Additional mappings
    r"vision_encoder": r"vision_tower",
    r"vision_language_adapter.w_in": r"multi_modal_projector.linear_1",
    r"vision_language_adapter.w_out": r"multi_modal_projector.linear_2",
    r"layers.(\d+).attention.wq.weight": r"language_model.model.layers.\1.self_attn.q_proj.weight",
    r"layers.(\d+).attention.wk.weight": r"language_model.model.layers.\1.self_attn.k_proj.weight",
    r"layers.(\d+).attention.wv.weight": r"language_model.model.layers.\1.self_attn.v_proj.weight",
    r"layers.(\d+).attention.wo.weight": r"language_model.model.layers.\1.self_attn.o_proj.weight",
    r"layers.(\d+).feed_forward.w1.weight": r"language_model.model.layers.\1.mlp.gate_proj.weight",
    r"layers.(\d+).feed_forward.w2.weight": r"language_model.model.layers.\1.mlp.down_proj.weight",
    r"layers.(\d+).feed_forward.w3.weight": r"language_model.model.layers.\1.mlp.up_proj.weight",
    r"layers.(\d+).ffn_norm.weight": r"language_model.model.layers.\1.post_attention_layernorm.weight",
    r"layers.(\d+).attention_norm.weight": r"language_model.model.layers.\1.input_layernorm.weight",
    r"tok_embeddings.weight": r"language_model.model.embed_tokens.weight",
    r"output.weight": r"language_model.lm_head.weight",
    r"norm.weight": r"language_model.model.norm.weight",
}


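
A regex table like the one above could be applied to checkpoint keys roughly as follows. This is a hedged sketch using only the standard `re` module and a three-entry excerpt of the table; the helper `rename_key` and the first-match-wins policy are assumptions for illustration, not necessarily how the conversion script applies the mapping.

```python
import re

# Excerpt of the mapping above. Order matters here: the specific
# vision-encoder pattern comes before the generic language-model one.
KEY_MAPPING_EXCERPT = {
    r"vision_encoder.transformer.layers.(\d+).attention.wq.weight": r"vision_tower.transformer.layers.\1.attention.q_proj.weight",
    r"layers.(\d+).attention.wq.weight": r"language_model.model.layers.\1.self_attn.q_proj.weight",
    r"tok_embeddings.weight": r"language_model.model.embed_tokens.weight",
}


def rename_key(old_key, mapping=KEY_MAPPING_EXCERPT):
    """Rename one checkpoint key using the first pattern that matches (assumed policy)."""
    for pattern, replacement in mapping.items():
        new_key, num_subs = re.subn(pattern, replacement, old_key)
        if num_subs:
            return new_key
    # Keys that match no pattern are returned unchanged.
    return old_key


for key in (
    "vision_encoder.transformer.layers.3.attention.wq.weight",
    "layers.11.attention.wq.weight",
    "tok_embeddings.weight",
):
    print(key, "->", rename_key(key))
```

Stopping at the first match matters because the generic `layers.(\d+)...` patterns would also match inside vision-encoder keys if they were tried first.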
class MistralConverter:
    """
    A general tiktoken converter.
    """

    def __init__(
        self,
        vocab=None,
        pattern=r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""",
        add_prefix_space=False,
        additional_special_tokens=None,
        *args,
        **kwargs,
    ):
        super().__init__(*args)
        self.vocab = vocab
        self.pattern = pattern
        self.add_prefix_space = add_prefix_space
        self.additional_special_tokens = additional_special_tokens

    def extract_vocab_merges_from_model(self, vocab: str):
        bpe_ranks = vocab
        byte_encoder = bytes_to_unicode()

        def token_bytes_to_string(b):
            return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")])

        merges = []
        vocab = {}
        for idx, (token, rank) in enumerate(bpe_ranks.items()):
            if token not in self.additional_special_tokens:
                vocab[token_bytes_to_string(token)] = idx
                if len(token) == 1:
                    continue
                local = []
                for index in range(1, len(token)):
                    piece_l, piece_r = token[:index], token[index:]
                    if piece_l in bpe_ranks and piece_r in bpe_ranks and (piece_l + piece_r) in bpe_ranks:
                        local.append((piece_l, piece_r, rank))
                local = sorted(local, key=lambda x: (bpe_ranks[x[0]], bpe_ranks[x[1]]), reverse=False)
                merges.extend(local)
            else:
                vocab[token] = idx
        merges = sorted(merges, key=lambda val: val[2], reverse=False)
        merges = [(token_bytes_to_string(val[0]), token_bytes_to_string(val[1])) for val in merges]
        return vocab, merges

    def tokenizer(self):
        vocab_scores, merges = self.extract_vocab_merges_from_model(self.vocab)
        tokenizer = Tokenizer(BPE(vocab_scores, merges, fuse_unk=False))
        if hasattr(tokenizer.model, "ignore_merges"):
            tokenizer.model.ignore_merges = True
        return tokenizer

    def converted(self) -> Tokenizer:
        tokenizer = self.tokenizer()
        tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
            [
                pre_tokenizers.Split(Regex(self.pattern), behavior="isolated", invert=False),
                pre_tokenizers.ByteLevel(add_prefix_space=self.add_prefix_space, use_regex=False),
            ]
        )
        tokenizer.decoder = decoders.ByteLevel()
        tokenizer.add_special_tokens(self.additional_special_tokens)

        tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)

        return tokenizer


|
def convert_mistral_tokenizer():
|
||||||
|
model_name = "mistralai/Pixtral-12B-2409"
|
||||||
|
|
||||||
|
tokenizer = MistralTokenizer.from_model(model_name)
|
||||||
|
|
||||||
|
vocab = tokenizer.instruct_tokenizer.tokenizer._tekken_token2id_nospecial
|
||||||
|
all_special = [
|
||||||
|
token.value if hasattr(token, "value") else token
|
||||||
|
for token in tokenizer.instruct_tokenizer.tokenizer._all_special_tokens
|
||||||
|
]
|
||||||
|
specials_tokens = {token: all_special.index(token) for token in all_special}
|
||||||
|
specials_tokens.update(vocab)
|
||||||
|
vocab = specials_tokens
|
||||||
|
|
||||||
|
tokenizer = PreTrainedTokenizerFast(
|
||||||
|
tokenizer_object=MistralConverter(vocab=vocab, additional_special_tokens=all_special).converted(),
|
||||||
|
bos_token="<s>",
|
||||||
|
unk_token="<unk>",
|
||||||
|
eos_token="</s>",
|
||||||
|
)
|
||||||
|
tokenizer.model_input_names = ["input_ids", "attention_mask"]
|
||||||
|
|
||||||
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def permute_for_rope(value, n_heads, config):
|
||||||
|
dim1 = value.shape[0]
|
||||||
|
dim2 = config.hidden_size
|
||||||
|
return value.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_dictionnary(original_state_dict, vision_config, text_config):
|
||||||
|
new_dict = {}
|
||||||
|
|
||||||
|
all_keys = "\n" + "\n".join(original_state_dict.keys())
|
||||||
|
old_keys = all_keys
|
||||||
|
for old, new in OLD_KEY_TO_NEW_KEY_MAPPING.items():
|
||||||
|
all_keys = re.sub(r"\n" + old, r"\n" + new, all_keys)
|
||||||
|
|
||||||
|
OLD_TO_NEW = dict(zip(old_keys.split("\n"), all_keys.split("\n")))
|
||||||
|
|
||||||
|
for key, value in original_state_dict.items():
|
||||||
|
new_key = OLD_TO_NEW[key]
|
||||||
|
if "vision_encoder" in key:
|
||||||
|
_config = vision_config
|
||||||
|
num_attention_heads = _config.num_attention_heads
|
||||||
|
else:
|
||||||
|
_config = text_config
|
||||||
|
if "q_proj" in new_key:
|
||||||
|
num_attention_heads = _config.num_attention_heads
|
||||||
|
if "k_proj" in new_key:
|
||||||
|
num_attention_heads = _config.num_key_value_heads
|
||||||
|
# convert the text model (basically mistral model)
|
||||||
|
|
||||||
|
if "q_proj" in new_key or "k_proj" in new_key:
|
||||||
|
value = permute_for_rope(value, num_attention_heads, _config)
|
||||||
|
|
||||||
|
new_dict[new_key] = value
|
||||||
|
return new_dict
|
||||||
|
|
||||||
|
|
||||||
|
def convert_mistral_model(input_dir, output_dir):
|
||||||
|
text_config = MistralConfig(
|
||||||
|
attention_dropout=0.0,
|
||||||
|
bos_token_id=1,
|
||||||
|
eos_token_id=2,
|
||||||
|
head_dim=128,
|
||||||
|
hidden_act="silu",
|
||||||
|
hidden_size=5120,
|
||||||
|
initializer_range=0.02,
|
||||||
|
intermediate_size=14336,
|
||||||
|
max_position_embeddings=1024000,
|
||||||
|
model_type="mistral",
|
||||||
|
num_attention_heads=32,
|
||||||
|
num_hidden_layers=40,
|
||||||
|
num_key_value_heads=8,
|
||||||
|
rms_norm_eps=1e-05,
|
||||||
|
rope_theta=1000000000.0,
|
||||||
|
sliding_window=None,
|
||||||
|
tie_word_embeddings=False,
|
||||||
|
vocab_size=131072,
|
||||||
|
)
|
||||||
|
|
||||||
|
vision_config = PixtralVisionConfig()
|
||||||
|
config = LlavaConfig(
|
||||||
|
vision_config,
|
||||||
|
text_config,
|
||||||
|
vision_feature_layer=-1,
|
||||||
|
image_token_index=10,
|
||||||
|
vision_feature_select_strategy="full",
|
||||||
|
image_seq_length=1,
|
||||||
|
)
|
||||||
|
config.architectures = ["LlavaForConditionalGeneration"]
|
||||||
|
config.save_pretrained(output_dir)
|
||||||
|
|
||||||
|
original_state_dict = safe_load_file(f"{input_dir}/consolidated.safetensors")
|
||||||
|
new_dict = convert_dictionnary(original_state_dict, vision_config, text_config)
|
||||||
|
|
||||||
|
with torch.device("meta"):
|
||||||
|
model = LlavaForConditionalGeneration(config)
|
||||||
|
model.load_state_dict(new_dict, strict=True, assign=True)
|
||||||
|
|
||||||
|
model.save_pretrained(output_dir)
|
||||||
|
|
||||||
|
tokenizer = convert_mistral_tokenizer()
|
||||||
|
image_processor = PixtralImageProcessor()
|
||||||
|
processor = PixtralProcessor(tokenizer=tokenizer, image_processor=image_processor, image_token="[IMG]")
|
||||||
|
processor.save_pretrained(output_dir)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--input_dir",
|
||||||
|
help="Location of the original Pixtral weights; must contain consolidated.safetensors",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output_dir",
|
||||||
|
help="Location to write the HF model, tokenizer and processor",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
convert_mistral_model(args.input_dir, args.output_dir)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
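The key renaming done in `convert_dictionnary` can be tried in isolation. The following is a minimal sketch, not the script itself: the first two mapping entries are the ones visible at the top of this section, while the third (with a capture group) is a hypothetical entry added only to illustrate backreferences, and `rename_keys` is an illustrative helper name.

```python
import re

# The first two entries appear in the mapping shown above. The third is a
# hypothetical pattern, included only to show how a capture group is reused.
KEY_MAPPING = {
    r"output.weight": r"language_model.lm_head.weight",
    r"norm.weight": r"language_model.model.norm.weight",
    r"layers.(\d+).attention.wq.weight": r"language_model.model.layers.\1.self_attn.q_proj.weight",
}


def rename_keys(keys):
    """Return {old_key: new_key} by rewriting all keys in one newline-joined string."""
    joined = "\n" + "\n".join(keys)
    renamed = joined
    for old, new in KEY_MAPPING.items():
        # Prefixing the pattern with a newline means it can only match at the start of a key.
        renamed = re.sub(r"\n" + old, r"\n" + new, renamed)
    # Drop the empty string produced by the leading newline.
    return dict(zip(joined.split("\n")[1:], renamed.split("\n")[1:]))


mapping = rename_keys(["output.weight", "norm.weight", "layers.3.attention.wq.weight"])
for old_key, new_key in mapping.items():
    print(old_key, "->", new_key)
```

Joining all keys into one string lets each pattern be applied once with `re.sub` rather than once per key, and the leading newline prevents `norm.weight` from matching inside a longer key.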
519
src/transformers/models/pixtral/image_processing_pixtral.py
Normal file
@@ -0,0 +1,519 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Image processor class for Pixtral."""
|
||||||
|
|
||||||
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
||||||
|
from ...image_transforms import (
|
||||||
|
resize,
|
||||||
|
to_channel_dimension_format,
|
||||||
|
)
|
||||||
|
from ...image_utils import (
|
||||||
|
ChannelDimension,
|
||||||
|
ImageInput,
|
||||||
|
PILImageResampling,
|
||||||
|
get_image_size,
|
||||||
|
infer_channel_dimension_format,
|
||||||
|
is_scaled_image,
|
||||||
|
is_valid_image,
|
||||||
|
to_numpy_array,
|
||||||
|
valid_images,
|
||||||
|
validate_kwargs,
|
||||||
|
validate_preprocess_arguments,
|
||||||
|
)
|
||||||
|
from ...utils import TensorType, is_torch_device, is_torch_dtype, is_torch_tensor, is_vision_available, logging
|
||||||
|
from ...utils.import_utils import requires_backends
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
if is_vision_available():
|
||||||
|
import PIL
|
||||||
|
|
||||||
|
|
||||||
|
class BatchMixFeature(BatchFeature):
|
||||||
|
def to(self, *args, **kwargs) -> "BatchMixFeature":
|
||||||
|
"""
|
||||||
|
Send all values to device by calling `v.to(*args, **kwargs)` (PyTorch only). This should support casting in
|
||||||
|
different `dtypes` and sending the `BatchFeature` to a different `device`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args (`Tuple`):
|
||||||
|
Will be passed to the `to(...)` function of the tensors.
|
||||||
|
kwargs (`Dict`, *optional*):
|
||||||
|
Will be passed to the `to(...)` function of the tensors.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[`BatchMixFeature`]: The same instance after modification.
|
||||||
|
"""
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
import torch # noqa
|
||||||
|
|
||||||
|
new_data = {}
|
||||||
|
device = kwargs.get("device")
|
||||||
|
# Check if the args are a device or a dtype
|
||||||
|
if device is None and len(args) > 0:
|
||||||
|
# device should be always the first argument
|
||||||
|
arg = args[0]
|
||||||
|
if is_torch_dtype(arg):
|
||||||
|
# The first argument is a dtype
|
||||||
|
pass
|
||||||
|
elif isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int):
|
||||||
|
device = arg
|
||||||
|
else:
|
||||||
|
# it's something else
|
||||||
|
raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")
|
||||||
|
# We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
|
||||||
|
for k, v in self.items():
|
||||||
|
# check if v is a floating point
|
||||||
|
if isinstance(v, list):
|
||||||
|
new_data[k] = [
|
||||||
|
element.to(*args, **kwargs) for sample in v for element in sample if is_torch_tensor(element)
|
||||||
|
]
|
||||||
|
elif torch.is_floating_point(v):
|
||||||
|
# cast and send to device
|
||||||
|
new_data[k] = v.to(*args, **kwargs)
|
||||||
|
elif device is not None:
|
||||||
|
new_data[k] = v.to(device=device)
|
||||||
|
else:
|
||||||
|
new_data[k] = v
|
||||||
|
self.data = new_data
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.idefics2.image_processing_idefics2.make_list_of_images
|
||||||
|
def make_list_of_images(images: ImageInput) -> List[List[np.ndarray]]:
|
||||||
|
"""
|
||||||
|
Convert a single image or a list of images to a list of numpy arrays.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
images (`ImageInput`):
|
||||||
|
A single image or a list of images.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of numpy arrays.
|
||||||
|
"""
|
||||||
|
# If it's a single image, convert it to a list of lists
|
||||||
|
if is_valid_image(images):
|
||||||
|
images = [[images]]
|
||||||
|
# If it's a list of images, it's a single batch, so convert it to a list of lists
|
||||||
|
elif isinstance(images, (list, tuple)) and len(images) > 0 and is_valid_image(images[0]):
|
||||||
|
images = [images]
|
||||||
|
# If it's a list of batches, it's already in the right format
|
||||||
|
elif (
|
||||||
|
isinstance(images, (list, tuple))
|
||||||
|
and len(images) > 0
|
||||||
|
and isinstance(images[0], (list, tuple))
|
||||||
|
and is_valid_image(images[0][0])
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Invalid input type. Must be a single image, a list of images, or a list of batches of images."
|
||||||
|
)
|
||||||
|
return images
|
||||||
|
|
||||||
|
|
||||||
|
# Adapted from function in image_transforms.py to ensure any transparent pixels are converted to white.
|
||||||
|
def convert_to_rgb(image: ImageInput) -> ImageInput:
|
||||||
|
"""
|
||||||
|
Converts an image to RGB format. Only converts if the image is of type PIL.Image.Image, otherwise returns the image
|
||||||
|
as is.
|
||||||
|
Args:
|
||||||
|
image (Image):
|
||||||
|
The image to convert.
|
||||||
|
"""
|
||||||
|
requires_backends(convert_to_rgb, ["vision"])
|
||||||
|
|
||||||
|
if not isinstance(image, PIL.Image.Image):
|
||||||
|
return image
|
||||||
|
|
||||||
|
if image.mode == "RGB":
|
||||||
|
return image
|
||||||
|
|
||||||
|
# First we convert to RGBA to set background to white.
|
||||||
|
image = image.convert("RGBA")
|
||||||
|
|
||||||
|
# Create a new image with a white background.
|
||||||
|
new_image = PIL.Image.new("RGBA", image.size, "WHITE")
|
||||||
|
new_image.paste(image, (0, 0), image)
|
||||||
|
new_image = new_image.convert("RGB")
|
||||||
|
return new_image
|
||||||
|
|
||||||
|
|
||||||
|
def _num_image_tokens(image_size: Tuple[int, int], patch_size: Tuple[int, int]) -> Tuple[int, int]:
|
||||||
|
"""
|
||||||
|
Calculate the number of image tokens given the image size and patch size.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_size (`Tuple[int, int]`):
|
||||||
|
The size of the image as `(height, width)`.
|
||||||
|
patch_size (`Tuple[int, int]`):
|
||||||
|
The patch size as `(height, width)`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`Tuple[int, int]`: The number of image tokens along the height and width, as `(num_height_tokens, num_width_tokens)`.
|
||||||
|
"""
|
||||||
|
height, width = image_size
|
||||||
|
patch_height, patch_width = patch_size if isinstance(patch_size, (tuple, list)) else (patch_size, patch_size)
|
||||||
|
num_width_tokens = (width - 1) // patch_width + 1
|
||||||
|
num_height_tokens = (height - 1) // patch_height + 1
|
||||||
|
return num_height_tokens, num_width_tokens
|
||||||
|
|
||||||
|
|
||||||
|
def get_resize_output_image_size(
|
||||||
|
input_image: np.ndarray,
|
||||||
|
size: Union[int, Tuple[int, int], List[int], Tuple[int]],
|
||||||
|
patch_size: Union[int, Tuple[int, int], List[int], Tuple[int]],
|
||||||
|
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
|
) -> tuple:
|
||||||
|
"""
|
||||||
|
Find the target (height, width) dimension of the output image after resizing given the input image and the desired
|
||||||
|
size.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_image (`np.ndarray`):
|
||||||
|
The image to resize.
|
||||||
|
size (`int` or `Tuple[int, int]`):
|
||||||
|
Maximum `(height, width)` an input image can have. If `size` is an integer, `(size, size)` will be used.
|
||||||
|
patch_size (`int` or `Tuple[int, int]`):
|
||||||
|
The patch_size as `(height, width)` to use for resizing the image. If patch_size is an integer, `(patch_size, patch_size)`
|
||||||
|
will be used.
|
||||||
|
input_data_format (`ChannelDimension`, *optional*):
|
||||||
|
The channel dimension format of the input image. If unset, will use the inferred format from the input.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`tuple`: The target (height, width) dimension of the output image after resizing.
|
||||||
|
"""
|
||||||
|
max_height, max_width = size if isinstance(size, (tuple, list)) else (size, size)
|
||||||
|
patch_height, patch_width = patch_size if isinstance(patch_size, (tuple, list)) else (patch_size, patch_size)
|
||||||
|
height, width = get_image_size(input_image, input_data_format)
|
||||||
|
|
||||||
|
ratio = max(height / max_height, width / max_width)
|
||||||
|
|
||||||
|
if ratio > 1:
|
||||||
|
# The original implementation uses `round`, which applies banker's rounding and can lead to surprising results
|
||||||
|
height = int(np.ceil(height / ratio))
|
||||||
|
width = int(np.ceil(width / ratio))
|
||||||
|
|
||||||
|
num_height_tokens, num_width_tokens = _num_image_tokens((height, width), (patch_height, patch_width))
|
||||||
|
return num_height_tokens * patch_height, num_width_tokens * patch_width
|
||||||
|
|
||||||
|
|
||||||
|
# Hack to get tensor conversion used in BatchFeature without batching the images
|
||||||
|
def _get_is_as_tensor_fns(tensor_type: Union[str, TensorType]) -> Tuple[Callable, Callable]:
|
||||||
|
return BatchFeature()._get_is_as_tensor_fns(tensor_type)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_to_tensor(array, tensor_type: Union[str, TensorType]) -> Any:
|
||||||
|
is_tensor, as_tensor = _get_is_as_tensor_fns(tensor_type)
|
||||||
|
if is_tensor(array):
|
||||||
|
return array
|
||||||
|
return as_tensor(array)
|
||||||
|
|
||||||
|
|
||||||
|
class PixtralImageProcessor(BaseImageProcessor):
|
||||||
|
r"""
|
||||||
|
Constructs a Pixtral image processor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
do_resize (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
|
||||||
|
`do_resize` in the `preprocess` method.
|
||||||
|
size (`Dict[str, int]` *optional*, defaults to `{"longest_edge": 1024}`):
|
||||||
|
Size of the maximum dimension of either the height or width dimension of the image. Used to control how
|
||||||
|
images are resized. If either the height or width is greater than `size["longest_edge"]`, both the height and width are rescaled to `height / ratio` and `width / ratio`, where `ratio = max(height / longest_edge, width / longest_edge)`.
|
||||||
|
patch_size (`Dict[str, int]` *optional*, defaults to `{"height": 16, "width": 16}`):
|
||||||
|
Size of the patches in the model, used to calculate the output image size. Can be overridden by `patch_size` in the `preprocess` method.
|
||||||
|
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
|
||||||
|
Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
|
||||||
|
do_rescale (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
|
||||||
|
the `preprocess` method.
|
||||||
|
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
||||||
|
Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
|
||||||
|
method.
|
||||||
|
do_normalize (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method.
|
||||||
|
image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
|
||||||
|
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
|
||||||
|
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
|
||||||
|
image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
|
||||||
|
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
|
||||||
|
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
|
||||||
|
do_convert_rgb (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to convert the image to RGB.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_input_names = ["pixel_values"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
do_resize: bool = True,
|
||||||
|
size: Dict[str, int] = None,
|
||||||
|
patch_size: Dict[str, int] = None,
|
||||||
|
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
||||||
|
do_rescale: bool = True,
|
||||||
|
rescale_factor: Union[int, float] = 1 / 255,
|
||||||
|
do_normalize: bool = True,
|
||||||
|
image_mean: Optional[Union[float, List[float]]] = None,
|
||||||
|
image_std: Optional[Union[float, List[float]]] = None,
|
||||||
|
do_convert_rgb: bool = True,
|
||||||
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
size = size if size is not None else {"longest_edge": 1024}
|
||||||
|
patch_size = patch_size if patch_size is not None else {"height": 16, "width": 16}
|
||||||
|
patch_size = get_size_dict(patch_size, default_to_square=True)
|
||||||
|
|
||||||
|
self.do_resize = do_resize
|
||||||
|
self.size = size
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.resample = resample
|
||||||
|
self.do_rescale = do_rescale
|
||||||
|
self.rescale_factor = rescale_factor
|
||||||
|
self.do_normalize = do_normalize
|
||||||
|
self.image_mean = image_mean if image_mean is not None else [0.48145466, 0.4578275, 0.40821073]
|
||||||
|
self.image_std = image_std if image_std is not None else [0.26862954, 0.26130258, 0.27577711]
|
||||||
|
self.do_convert_rgb = do_convert_rgb
|
||||||
|
self._valid_processor_keys = [
|
||||||
|
"images",
|
||||||
|
"do_resize",
|
||||||
|
"size",
|
||||||
|
"patch_size",
|
||||||
|
"resample",
|
||||||
|
"do_rescale",
|
||||||
|
"rescale_factor",
|
||||||
|
"do_normalize",
|
||||||
|
"image_mean",
|
||||||
|
"image_std",
|
||||||
|
"do_convert_rgb",
|
||||||
|
"return_tensors",
|
||||||
|
"data_format",
|
||||||
|
"input_data_format",
|
||||||
|
]
|
||||||
|
|
||||||
|
def resize(
|
||||||
|
self,
|
||||||
|
image: np.ndarray,
|
||||||
|
size: Dict[str, int],
|
||||||
|
patch_size: Dict[str, int],
|
||||||
|
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
||||||
|
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
|
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Resize an image so that it fits within the maximum dimensions given by `size`, keeping the input aspect ratio. The
|
||||||
|
height and width are then rounded up to a multiple of the patch size.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (`np.ndarray`):
|
||||||
|
Image to resize.
|
||||||
|
size (`Dict[str, int]`):
|
||||||
|
Dict containing the longest possible edge of the image.
|
||||||
|
patch_size (`Dict[str, int]`):
|
||||||
|
Patch size used to calculate the size of the output image.
|
||||||
|
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
|
||||||
|
Resampling filter to use when resizing the image.
|
||||||
|
data_format (`str` or `ChannelDimension`, *optional*):
|
||||||
|
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||||
|
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||||
|
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||||
|
"""
|
||||||
|
if "longest_edge" in size:
|
||||||
|
size = (size["longest_edge"], size["longest_edge"])
|
||||||
|
elif "height" in size and "width" in size:
|
||||||
|
size = (size["height"], size["width"])
|
||||||
|
else:
|
||||||
|
raise ValueError("size must contain either 'longest_edge' or 'height' and 'width'.")
|
||||||
|
|
||||||
|
if "height" in patch_size and "width" in patch_size:
|
||||||
|
patch_size = (patch_size["height"], patch_size["width"])
|
||||||
|
else:
|
||||||
|
raise ValueError("patch_size must contain 'height' and 'width'.")
|
||||||
|
|
||||||
|
output_size = get_resize_output_image_size(
|
||||||
|
image,
|
||||||
|
size=size,
|
||||||
|
patch_size=patch_size,
|
||||||
|
input_data_format=input_data_format,
|
||||||
|
)
|
||||||
|
return resize(
|
||||||
|
image,
|
||||||
|
size=output_size,
|
||||||
|
resample=resample,
|
||||||
|
data_format=data_format,
|
||||||
|
input_data_format=input_data_format,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def preprocess(
|
||||||
|
self,
|
||||||
|
images: ImageInput,
|
||||||
|
do_resize: bool = None,
|
||||||
|
size: Dict[str, int] = None,
|
||||||
|
patch_size: Dict[str, int] = None,
|
||||||
|
resample: PILImageResampling = None,
|
||||||
|
do_rescale: bool = None,
|
||||||
|
rescale_factor: float = None,
|
||||||
|
do_normalize: bool = None,
|
||||||
|
image_mean: Optional[Union[float, List[float]]] = None,
|
||||||
|
image_std: Optional[Union[float, List[float]]] = None,
|
||||||
|
do_convert_rgb: bool = None,
|
||||||
|
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||||
|
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
||||||
|
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> BatchMixFeature:
|
||||||
|
"""
|
||||||
|
Preprocess an image or batch of images.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
images (`ImageInput`):
|
||||||
|
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
|
||||||
|
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
|
||||||
|
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
||||||
|
Whether to resize the image.
|
||||||
|
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
||||||
|
Describes the maximum input dimensions to the model.
|
||||||
|
patch_size (`Dict[str, int]`, *optional*, defaults to `self.patch_size`):
|
||||||
|
Patch size in the model. Used to calculate the image after resizing.
|
||||||
|
resample (`int`, *optional*, defaults to `self.resample`):
|
||||||
|
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
|
||||||
|
has an effect if `do_resize` is set to `True`.
|
||||||
|
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
||||||
|
Whether to rescale the image.
|
||||||
|
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
||||||
|
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
||||||
|
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
||||||
|
Whether to normalize the image.
|
||||||
|
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
||||||
|
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
|
||||||
|
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
||||||
|
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
|
||||||
|
`True`.
|
||||||
|
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
||||||
|
Whether to convert the image to RGB.
|
||||||
|
return_tensors (`str` or `TensorType`, *optional*):
|
||||||
|
The type of tensors to return. Can be one of:
|
||||||
|
- Unset: Return a list of `np.ndarray`.
|
||||||
|
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
||||||
|
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
||||||
|
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
||||||
|
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
||||||
|
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
||||||
|
The channel dimension format for the output image. Can be one of:
|
||||||
|
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||||
|
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||||
|
- Unset: Use the channel dimension format of the input image.
|
||||||
|
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||||
|
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||||
|
from the input image. Can be one of:
|
||||||
|
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||||
|
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||||
|
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||||
|
"""
|
||||||
|
patch_size = patch_size if patch_size is not None else self.patch_size
|
||||||
|
patch_size = get_size_dict(patch_size, default_to_square=True)
|
||||||
|
|
||||||
|
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||||
|
size = size if size is not None else self.size
|
||||||
|
resample = resample if resample is not None else self.resample
|
||||||
|
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
||||||
|
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
||||||
|
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
||||||
|
image_mean = image_mean if image_mean is not None else self.image_mean
|
||||||
|
image_std = image_std if image_std is not None else self.image_std
|
||||||
|
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
|
||||||
|
|
||||||
|
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)
|
||||||
|
|
||||||
|
images_list = make_list_of_images(images)
|
||||||
|
|
||||||
|
if not valid_images(images_list[0]):
|
||||||
|
raise ValueError(
|
||||||
|
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
||||||
|
"torch.Tensor, tf.Tensor or jax.ndarray."
|
||||||
|
)
|
||||||
|
|
||||||
|
validate_preprocess_arguments(
|
||||||
|
do_rescale=do_rescale,
|
||||||
|
rescale_factor=rescale_factor,
|
||||||
|
do_normalize=do_normalize,
|
||||||
|
image_mean=image_mean,
|
||||||
|
image_std=image_std,
|
||||||
|
do_resize=do_resize,
|
||||||
|
size=size,
|
||||||
|
resample=resample,
|
||||||
|
)
|
||||||
|
|
||||||
|
if do_convert_rgb:
|
||||||
|
images_list = [[convert_to_rgb(image) for image in images] for images in images_list]
|
||||||
|
|
||||||
|
# All transformations expect numpy arrays.
|
||||||
|
images_list = [[to_numpy_array(image) for image in images] for images in images_list]
|
||||||
|
|
||||||
|
if is_scaled_image(images_list[0][0]) and do_rescale:
|
||||||
|
logger.warning_once(
|
||||||
|
"It looks like you are trying to rescale already rescaled images. If the input"
|
||||||
|
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
||||||
|
)
|
||||||
|
|
||||||
|
if input_data_format is None:
|
||||||
|
# We assume that all images have the same channel dimension format.
|
||||||
|
input_data_format = infer_channel_dimension_format(images_list[0][0])
|
||||||
|
|
||||||
|
batch_images = []
|
||||||
|
batch_image_sizes = []
|
||||||
|
for sample_images in images_list:
|
||||||
|
images = []
|
||||||
|
image_sizes = []
|
||||||
|
for image in sample_images:
|
||||||
|
if do_resize:
|
||||||
|
image = self.resize(
|
||||||
|
image=image,
|
||||||
|
size=size,
|
||||||
|
patch_size=patch_size,
|
||||||
|
resample=resample,
|
||||||
|
input_data_format=input_data_format,
|
||||||
|
)
|
||||||
|
|
||||||
|
if do_rescale:
|
||||||
|
image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
||||||
|
|
||||||
|
if do_normalize:
|
||||||
|
image = self.normalize(
|
||||||
|
image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
|
||||||
|
)
|
||||||
|
|
||||||
|
images.append(image)
|
||||||
|
image_sizes.append(get_image_size(image, input_data_format))
|
||||||
|
batch_images.append(images)
|
||||||
|
batch_image_sizes.append(image_sizes)
|
||||||
|
|
||||||
|
images_list = [
|
||||||
|
[to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images]
|
||||||
|
for images in batch_images
|
||||||
|
]
|
||||||
|
|
||||||
|
# Convert to tensor type outside of BatchFeature to avoid batching the images of different sizes
|
||||||
|
images_list = [[convert_to_tensor(image, return_tensors) for image in images] for images in images_list]
|
||||||
|
return BatchMixFeature(data={"pixel_values": images_list, "image_sizes": batch_image_sizes}, tensor_type=None)
|
||||||
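To make the resizing rule in `get_resize_output_image_size` concrete, here is a small standalone sketch of the same arithmetic. It is an illustration under stated assumptions rather than the library code: `target_size` is a hypothetical helper name, it uses `math.ceil` in place of `np.ceil`, and it assumes a square limit (`longest_edge`) and a square patch.

```python
import math


def target_size(height, width, longest_edge=1024, patch=16):
    """Return (out_height, out_width, rows, cols) for an image of the given size."""
    # Shrink only when the image exceeds the limit on either side.
    ratio = max(height / longest_edge, width / longest_edge)
    if ratio > 1:
        height = int(math.ceil(height / ratio))
        width = int(math.ceil(width / ratio))

    # Ceiling division: number of patches needed to cover each dimension.
    rows = (height - 1) // patch + 1
    cols = (width - 1) // patch + 1

    # The output size is the patch grid size, so it is always a multiple of `patch`.
    return rows * patch, cols * patch, rows, cols


print(target_size(2000, 1000))  # (1024, 512, 64, 32)
print(target_size(500, 300))    # (512, 304, 32, 19)
```

Note the second case: an image already below the limit is not shrunk, but its dimensions are still rounded up to the next multiple of the patch size, so 500 x 300 becomes 512 x 304.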
517
src/transformers/models/pixtral/modeling_pixtral.py
Normal file
@@ -0,0 +1,517 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""PyTorch Pixtral model."""
|
||||||
|
|
||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.utils.checkpoint
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from ... import PreTrainedModel
|
||||||
|
from ...activations import ACT2FN
|
||||||
|
from ...modeling_outputs import BaseModelOutput
|
||||||
|
from ...utils import (
|
||||||
|
add_start_docstrings,
|
||||||
|
add_start_docstrings_to_model_forward,
|
||||||
|
logging,
|
||||||
|
)
|
||||||
|
from .configuration_pixtral import PixtralVisionConfig
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def position_ids_in_meshgrid(patch_embeds_list, max_width):
|
||||||
|
positions = []
|
||||||
|
for patch in patch_embeds_list:
|
||||||
|
height, width = patch.shape[-2:]
|
||||||
|
mesh = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij")
|
||||||
|
h_grid, v_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1)
|
||||||
|
ids = h_grid * max_width + v_grid
|
||||||
|
positions.append(ids[:, 0])
|
||||||
|
return torch.cat(positions)
|
||||||
|
|
||||||
|
|
||||||
|
class PixtralRotaryEmbedding(nn.Module):
|
||||||
|
"""
|
||||||
|
The key idea of the Pixtral embedding is that there is a frequency for each pixel position.
|
||||||
|
If you have height x width pixels (or embedding pixels)
|
||||||
|
|
||||||
|
then the frequency used for ROPE is given by indexing the pre_computed frequency on the
|
||||||
|
width and height.
|
||||||
|
|
||||||
|
What you output is of dimension batch, height * width, dim with dim the embed dim.
|
||||||
|
|
||||||
|
This simply means that for each image hidden state, you are going to add
|
||||||
|
a corresponding positional embedding, based on its index in the grid.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config, device):
|
||||||
|
super().__init__()
|
||||||
|
self.rope_type = "default"
|
||||||
|
self.dim = config.head_dim
|
||||||
|
self.base = config.rope_theta
|
||||||
|
max_patches_per_side = config.image_size // config.patch_size
|
||||||
|
freqs = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
|
||||||
|
|
||||||
|
h = torch.arange(max_patches_per_side, device=freqs.device)
|
||||||
|
w = torch.arange(max_patches_per_side, device=freqs.device)
|
||||||
|
|
||||||
|
freqs_h = torch.outer(h, freqs[::2]).float()
|
||||||
|
freqs_w = torch.outer(w, freqs[1::2]).float()
|
||||||
|
inv_freq = torch.cat(
|
||||||
|
[
|
||||||
|
freqs_h[:, None, :].repeat(1, max_patches_per_side, 1),
|
||||||
|
freqs_w[None, :, :].repeat(max_patches_per_side, 1, 1),
|
||||||
|
],
|
||||||
|
dim=-1,
|
||||||
|
).reshape(-1, self.dim // 2) # we reshape to only index on the position indexes, not tuple of indexes
|
||||||
|
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||||
|
|
||||||
|
# TODO maybe make it torch compatible later on. We can also just slice
|
||||||
|
self.register_buffer("inv_freq", torch.cat((inv_freq, inv_freq), dim=-1), persistent=False)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def forward(self, x, position_ids):
|
||||||
|
if "dynamic" in self.rope_type:
|
||||||
|
self._dynamic_frequency_update(position_ids, device=x.device)
|
||||||
|
|
||||||
|
# Core RoPE block
|
||||||
|
freqs = self.inv_freq[position_ids]
|
||||||
|
# position_ids_expanded = position_ids[:, None, :].float()
|
||||||
|
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
|
||||||
|
device_type = x.device.type
|
||||||
|
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||||
|
with torch.autocast(device_type=device_type, enabled=False):
|
||||||
|
emb = freqs
|
||||||
|
cos = emb.cos()
|
||||||
|
sin = emb.sin()
|
||||||
|
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||||
|
|
||||||
|
def _dynamic_frequency_update(self, position_ids, device):
|
||||||
|
"""
|
||||||
|
dynamic RoPE layers should recompute `inv_freq` in the following situations:
|
||||||
|
1 - growing beyond the cached sequence length (allow scaling)
|
||||||
|
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
|
||||||
|
"""
|
||||||
|
seq_len = torch.max(position_ids) + 1
|
||||||
|
if seq_len > self.max_seq_len_cached: # growth
|
||||||
|
inv_freq, self.attention_scaling = self.rope_init_fn(
|
||||||
|
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
||||||
|
)
|
||||||
|
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
||||||
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
|
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||||
|
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||||
|
self.max_seq_len_cached = self.original_max_seq_len
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
||||||
|
def rotate_half(x):
|
||||||
|
"""Rotates half the hidden dims of the input."""
|
||||||
|
x1 = x[..., : x.shape[-1] // 2]
|
||||||
|
x2 = x[..., x.shape[-1] // 2 :]
|
||||||
|
return torch.cat((-x2, x1), dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
||||||
|
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||||
|
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
q (`torch.Tensor`): The query tensor.
|
||||||
|
k (`torch.Tensor`): The key tensor.
|
||||||
|
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
||||||
|
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
||||||
|
position_ids (`torch.Tensor`, *optional*):
|
||||||
|
Deprecated and unused.
|
||||||
|
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
||||||
|
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
||||||
|
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
||||||
|
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
||||||
|
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
||||||
|
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
||||||
|
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
||||||
|
Returns:
|
||||||
|
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
||||||
|
"""
|
||||||
|
cos = cos.unsqueeze(unsqueeze_dim)
|
||||||
|
sin = sin.unsqueeze(unsqueeze_dim)
|
||||||
|
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||||
|
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||||
|
return q_embed, k_embed
|
||||||
|
|
||||||
|
|
||||||
|
class PixtralAttention(nn.Module):
|
||||||
|
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.embed_dim = config.hidden_size
|
||||||
|
self.num_heads = config.num_attention_heads
|
||||||
|
self.head_dim = self.embed_dim // self.num_heads
|
||||||
|
|
||||||
|
self.scale = self.head_dim**-0.5
|
||||||
|
self.dropout = config.attention_dropout
|
||||||
|
|
||||||
|
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
|
||||||
|
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
|
||||||
|
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
|
||||||
|
self.o_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_embeddings: Optional[torch.Tensor] = None,
|
||||||
|
output_attentions: Optional[bool] = False,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
|
"""Input shape: Batch x Time x Channel"""
|
||||||
|
|
||||||
|
batch_size, patches, _ = hidden_states.size()
|
||||||
|
|
||||||
|
query_states = self.q_proj(hidden_states)
|
||||||
|
key_states = self.k_proj(hidden_states)
|
||||||
|
value_states = self.v_proj(hidden_states)
|
||||||
|
|
||||||
|
query_states = query_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
key_states = key_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
cos, sin = position_embeddings
|
||||||
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=0)
|
||||||
|
|
||||||
|
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
attn_weights = attn_weights + attention_mask
|
||||||
|
|
||||||
|
# upcast attention to fp32
|
||||||
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||||
|
attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
||||||
|
attn_output = torch.matmul(attn_weights, value_states)
|
||||||
|
|
||||||
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||||
|
attn_output = attn_output.reshape(batch_size, patches, -1)
|
||||||
|
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
|
return attn_output, attn_weights
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Pixtral
|
||||||
|
class PixtralMLP(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
self.intermediate_size = config.intermediate_size
|
||||||
|
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||||
|
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||||
|
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
||||||
|
self.act_fn = ACT2FN[config.hidden_act]
|
||||||
|
|
||||||
|
def forward(self, hidden_state):
|
||||||
|
return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Pixtral
|
||||||
|
class PixtralRMSNorm(nn.Module):
|
||||||
|
def __init__(self, hidden_size, eps=1e-6):
|
||||||
|
"""
|
||||||
|
PixtralRMSNorm is equivalent to T5LayerNorm
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||||
|
self.variance_epsilon = eps
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
input_dtype = hidden_states.dtype
|
||||||
|
hidden_states = hidden_states.to(torch.float32)
|
||||||
|
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||||
|
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||||
|
return self.weight * hidden_states.to(input_dtype)
|
||||||
|
|
||||||
|
def extra_repr(self):
|
||||||
|
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
||||||
|
|
||||||
|
|
||||||
|
class PixtralAttentionLayer(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.attention_norm = PixtralRMSNorm(config.hidden_size, eps=1e-5)
|
||||||
|
self.feed_forward = PixtralMLP(config)
|
||||||
|
self.attention = PixtralAttention(config)
|
||||||
|
self.ffn_norm = PixtralRMSNorm(config.hidden_size, eps=1e-5)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
position_embeddings: Optional[torch.Tensor] = None,
|
||||||
|
output_attentions: Optional[bool] = False,
|
||||||
|
) -> Tuple[torch.FloatTensor]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
hidden_states (`torch.FloatTensor`):
|
||||||
|
Input to the layer of shape `(batch, seq_len, embed_dim)`.
|
||||||
|
attention_mask (`torch.FloatTensor`):
|
||||||
|
Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
|
||||||
|
output_attentions (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||||
|
returned tensors for more detail.
|
||||||
|
"""
|
||||||
|
residual = hidden_states
|
||||||
|
|
||||||
|
hidden_states = self.attention_norm(hidden_states)
|
||||||
|
hidden_states, attn_weights = self.attention(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_embeddings=position_embeddings,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.ffn_norm(hidden_states)
|
||||||
|
hidden_states = self.feed_forward(hidden_states)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
outputs = (hidden_states,)
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
outputs += (attn_weights,)
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
class PixtralTransformer(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.layers = torch.nn.ModuleList()
|
||||||
|
for _ in range(config.num_hidden_layers):
|
||||||
|
self.layers.append(PixtralAttentionLayer(config))
|
||||||
|
self.gradient_checkpointing = False
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
inputs_embeds,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_embeddings: Optional[torch.Tensor] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
) -> Union[Tuple, BaseModelOutput]:
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
||||||
|
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
|
||||||
|
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
||||||
|
than the model's internal embedding lookup matrix.
|
||||||
|
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||||
|
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
|
[What are attention masks?](../glossary#attention-mask)
|
||||||
|
output_attentions (`bool`, *optional*):
|
||||||
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||||
|
returned tensors for more detail.
|
||||||
|
output_hidden_states (`bool`, *optional*):
|
||||||
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
||||||
|
for more detail.
|
||||||
|
return_dict (`bool`, *optional*):
|
||||||
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||||
|
"""
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
encoder_states = () if output_hidden_states else None
|
||||||
|
all_attentions = () if output_attentions else None
|
||||||
|
|
||||||
|
hidden_states = inputs_embeds
|
||||||
|
for encoder_layer in self.layers:
|
||||||
|
if output_hidden_states:
|
||||||
|
encoder_states = encoder_states + (hidden_states,)
|
||||||
|
if self.gradient_checkpointing and self.training:
|
||||||
|
layer_outputs = self._gradient_checkpointing_func(
|
||||||
|
encoder_layer.__call__,
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
position_embeddings,
|
||||||
|
output_attentions,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
layer_outputs = encoder_layer(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
position_embeddings=position_embeddings,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
all_attentions = all_attentions + (layer_outputs[1],)
|
||||||
|
|
||||||
|
if output_hidden_states:
|
||||||
|
encoder_states = encoder_states + (hidden_states,)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
|
||||||
|
return BaseModelOutput(
|
||||||
|
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
PIXTRAL_START_DOCSTRING = r"""
|
||||||
|
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||||
|
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
|
||||||
|
etc.)
|
||||||
|
|
||||||
|
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
||||||
|
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
|
||||||
|
and behavior.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
config ([`PixtralVisionConfig`]):
|
||||||
|
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
||||||
|
load the weights associated with the model, only the configuration. Check out the
|
||||||
|
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
|
||||||
|
"The bare Pixtral Model outputting raw hidden-states without any specific head on top.",
|
||||||
|
PIXTRAL_START_DOCSTRING,
|
||||||
|
)
|
||||||
|
class PixtralPreTrainedModel(PreTrainedModel):
|
||||||
|
config_class = PixtralVisionConfig
|
||||||
|
base_model_prefix = "model"
|
||||||
|
supports_gradient_checkpointing = True
|
||||||
|
_no_split_modules = ["PixtralAttentionLayer"]
|
||||||
|
_skip_keys_device_placement = "past_key_values"
|
||||||
|
_supports_cache_class = True
|
||||||
|
|
||||||
|
def _init_weights(self, module):
|
||||||
|
# important: this ported version of Pixtral isn't meant for training from scratch - only
|
||||||
|
# inference and fine-tuning - so the proper init weights code has been removed - the original codebase
|
||||||
|
# https://github.com/haotian-liu/LLaVA/tree/main/pixtral should serve for that purpose
|
||||||
|
std = (
|
||||||
|
self.config.initializer_range
|
||||||
|
if hasattr(self.config, "initializer_range")
|
||||||
|
else self.config.text_config.initializer_range
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||||
|
module.weight.data.normal_(mean=0.0, std=std)
|
||||||
|
if module.bias is not None:
|
||||||
|
module.bias.data.zero_()
|
||||||
|
elif isinstance(module, nn.Embedding):
|
||||||
|
module.weight.data.normal_(mean=0.0, std=std)
|
||||||
|
if module.padding_idx is not None:
|
||||||
|
module.weight.data[module.padding_idx].zero_()
|
||||||
|
|
||||||
|
|
||||||
|
PIXTRAL_INPUTS_DOCSTRING = r"""
|
||||||
|
Args:
|
||||||
|
pixel_values: list of N_img images of variable sizes,
|
||||||
|
each of shape (C, H, W)
|
||||||
|
output_attentions (`bool`, *optional*):
|
||||||
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||||
|
tensors for more detail.
|
||||||
|
output_hidden_states (`bool`, *optional*):
|
||||||
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||||
|
more detail.
|
||||||
|
return_dict (`bool`, *optional*):
|
||||||
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def generate_block_attention_mask(patch_embeds_list, tensor):
|
||||||
|
dtype = tensor.dtype
|
||||||
|
device = tensor.device
|
||||||
|
seq_len = tensor.shape[1]
|
||||||
|
d_min = torch.finfo(dtype).min
|
||||||
|
causal_mask = torch.full((seq_len, seq_len), fill_value=d_min, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
block_end_idx = torch.tensor(patch_embeds_list).cumsum(-1)
|
||||||
|
block_start_idx = torch.tensor([0] + patch_embeds_list[:-1]).cumsum(-1)
|
||||||
|
for start, end in zip(block_start_idx, block_end_idx):
|
||||||
|
causal_mask[start:end, start:end] = 0
|
||||||
|
|
||||||
|
causal_mask = causal_mask[None, None, :, :].expand(tensor.shape[0], 1, -1, -1)
|
||||||
|
return causal_mask
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
|
||||||
|
"""The PIXTRAL model which consists of a vision backbone and a language model.""",
|
||||||
|
PIXTRAL_START_DOCSTRING,
|
||||||
|
)
|
||||||
|
class PixtralModel(PixtralPreTrainedModel):
|
||||||
|
base_model_prefix = "vision_encoder"
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
self.config = config
|
||||||
|
self.patch_conv = nn.Conv2d(
|
||||||
|
in_channels=config.num_channels,
|
||||||
|
out_channels=config.hidden_size,
|
||||||
|
kernel_size=config.patch_size,
|
||||||
|
stride=config.patch_size,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.ln_pre = PixtralRMSNorm(config.hidden_size, eps=1e-5)
|
||||||
|
self.transformer = PixtralTransformer(config)
|
||||||
|
self.patch_positional_embedding = PixtralRotaryEmbedding(config, device=self.device)
|
||||||
|
|
||||||
|
@add_start_docstrings_to_model_forward(PIXTRAL_INPUTS_DOCSTRING)
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
pixel_values: List[torch.Tensor],
|
||||||
|
output_hidden_states: Optional[bool] = False,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
) -> Union[Tuple, BaseModelOutput]:
|
||||||
|
"""
|
||||||
|
Returns:
|
||||||
|
pixel_values: tensor of token features for
|
||||||
|
all tokens of all images of shape (N_toks, D)
|
||||||
|
"""
|
||||||
|
# pass images through initial convolution independently
|
||||||
|
patch_embeds_list = [self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in pixel_values]
|
||||||
|
|
||||||
|
# flatten to a single sequence
|
||||||
|
patch_embeds = torch.cat([p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1)
|
||||||
|
patch_embeds = self.ln_pre(patch_embeds)
|
||||||
|
|
||||||
|
# positional embeddings
|
||||||
|
position_ids = position_ids_in_meshgrid(
|
||||||
|
patch_embeds_list, max_width=self.config.image_size // self.config.patch_size
|
||||||
|
).to(self.device)
|
||||||
|
|
||||||
|
position_embedding = self.patch_positional_embedding(patch_embeds, position_ids)
|
||||||
|
attention_mask = generate_block_attention_mask(
|
||||||
|
[p.shape[-2] * p.shape[-1] for p in patch_embeds_list], patch_embeds
|
||||||
|
)
|
||||||
|
return self.transformer(patch_embeds, attention_mask, position_embedding)
|
||||||
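The two helpers `position_ids_in_meshgrid` and `generate_block_attention_mask` above are easier to follow on a tiny input. The sketch below restates their logic in pure Python, assuming patch grids are given as `(height, width)` pairs rather than tensors. The names `position_ids` and `block_mask` are hypothetical, and `float("-inf")` stands in for the `torch.finfo(dtype).min` fill value used in the model code.

```python
def position_ids(grid_shapes, max_width):
    # Row-major ids per image: id = row * max_width + col, restarting for each image.
    ids = []
    for height, width in grid_shapes:
        for row in range(height):
            for col in range(width):
                ids.append(row * max_width + col)
    return ids


def block_mask(patch_counts, masked=float("-inf")):
    # Additive mask over the concatenated sequence: 0.0 inside each image's
    # own block, `masked` everywhere else, so patches attend only within an image.
    total = sum(patch_counts)
    mask = [[masked] * total for _ in range(total)]
    start = 0
    for count in patch_counts:
        end = start + count
        for i in range(start, end):
            for j in range(start, end):
                mask[i][j] = 0.0
        start = end
    return mask
```

For a 2x3 grid followed by a 1x2 grid with `max_width=4`, the ids are `[0, 1, 2, 4, 5, 6, 0, 1]`: the second row starts at 4 because ids are laid out for the widest possible grid, which lets the same id index the precomputed frequency table regardless of the image's actual width.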
282
src/transformers/models/pixtral/processing_pixtral.py
Normal file
@@ -0,0 +1,282 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
Processor class for Pixtral.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
|
from ...feature_extraction_utils import BatchFeature
|
||||||
|
from ...image_utils import ImageInput, is_valid_image, load_image
|
||||||
|
from ...processing_utils import ProcessorMixin
|
||||||
|
from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
|
||||||
|
from ...utils import TensorType, is_torch_device, is_torch_dtype, is_torch_tensor, logging, requires_backends
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.idefics2.processing_idefics2.is_url
|
||||||
|
def is_url(val) -> bool:
|
||||||
|
return isinstance(val, str) and val.startswith("http")
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.idefics2.processing_idefics2.is_image_or_image_url
|
||||||
|
def is_image_or_image_url(elem):
|
||||||
|
return is_url(elem) or is_valid_image(elem)
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.pixtral.image_processing_pixtral.BatchMixFeature
|
||||||
|
class BatchMixFeature(BatchFeature):
|
||||||
|
def to(self, *args, **kwargs) -> "BatchMixFeature":
|
||||||
|
"""
|
||||||
|
Send all values to device by calling `v.to(*args, **kwargs)` (PyTorch only). This should support casting in
|
||||||
|
different `dtypes` and sending the `BatchFeature` to a different `device`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args (`Tuple`):
|
||||||
|
Will be passed to the `to(...)` function of the tensors.
|
||||||
|
kwargs (`Dict`, *optional*):
|
||||||
|
Will be passed to the `to(...)` function of the tensors.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[`BatchFeature`]: The same instance after modification.
|
||||||
|
"""
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
import torch # noqa
|
||||||
|
|
||||||
|
new_data = {}
|
||||||
|
device = kwargs.get("device")
|
||||||
|
# Check if the args are a device or a dtype
|
||||||
|
if device is None and len(args) > 0:
|
||||||
|
# device should be always the first argument
|
||||||
|
arg = args[0]
|
||||||
|
if is_torch_dtype(arg):
|
||||||
|
# The first argument is a dtype
|
||||||
|
pass
|
||||||
|
elif isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int):
|
||||||
|
device = arg
|
||||||
|
else:
|
||||||
|
# it's something else
|
||||||
|
raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")
|
||||||
|
# We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
|
||||||
|
for k, v in self.items():
|
||||||
|
# check if v is a floating point
|
||||||
|
if isinstance(v, list):
|
||||||
|
new_data[k] = [
|
||||||
|
element.to(*args, **kwargs) for sample in v for element in sample if is_torch_tensor(element)
|
||||||
|
]
|
||||||
|
elif torch.is_floating_point(v):
|
||||||
|
# cast and send to device
|
||||||
|
new_data[k] = v.to(*args, **kwargs)
|
||||||
|
elif device is not None:
|
||||||
|
new_data[k] = v.to(device=device)
|
||||||
|
else:
|
||||||
|
new_data[k] = v
|
||||||
|
self.data = new_data
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
class PixtralProcessor(ProcessorMixin):
|
||||||
|
r"""
|
||||||
|
Constructs a Pixtral processor which wraps a Pixtral image processor and a Pixtral tokenizer into a single processor.
|
||||||
|
|
||||||
|
[`PixtralProcessor`] offers all the functionalities of [`CLIPImageProcessor`] and [`LlamaTokenizerFast`]. See the
|
||||||
|
[`~PixtralProcessor.__call__`] and [`~PixtralProcessor.decode`] for more information.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_processor ([`PixtralImageProcessor`], *optional*):
|
||||||
|
The image processor is a required input.
|
||||||
|
tokenizer ([`LlamaTokenizerFast`], *optional*):
|
||||||
|
The tokenizer is a required input.
|
||||||
|
patch_size (`int`, *optional*, defaults to 16):
|
||||||
|
Patch size from the vision tower.
|
||||||
|
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
|
||||||
|
in a chat into a tokenizable string.
|
||||||
|
image_token (`str`, *optional*, defaults to `"[IMG]"`):
|
||||||
|
Special token used to denote image location.
|
||||||
|
image_break_token (`str`, *optional*, defaults to `"[IMG_BREAK]"`):
|
||||||
|
Special token used to denote the end of a line of pixels in an image.
|
||||||
|
image_end_token (`str`, *optional*, defaults to `"[IMG_END]"`):
|
||||||
|
Special token used to denote the end of an image input.
|
||||||
|
"""
|
||||||
|
|
||||||
|
attributes = ["image_processor", "tokenizer"]
|
||||||
|
valid_kwargs = [
|
||||||
|
"chat_template",
|
||||||
|
"patch_size",
|
||||||
|
"image_token",
|
||||||
|
"image_break_token",
|
||||||
|
"image_end_token",
|
||||||
|
]
|
||||||
|
image_processor_class = "AutoImageProcessor"
|
||||||
|
tokenizer_class = "AutoTokenizer"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
image_processor=None,
|
||||||
|
tokenizer=None,
|
||||||
|
patch_size: int = 16,
|
||||||
|
chat_template=None,
|
||||||
|
image_token="[IMG]", # set the default and let users change if they have peculiar special tokens in rare cases
|
||||||
|
image_break_token="[IMG_BREAK]",
|
||||||
|
image_end_token="[IMG_END]",
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.image_token = image_token
|
||||||
|
self.image_break_token = image_break_token
|
||||||
|
self.image_end_token = image_end_token
|
||||||
|
super().__init__(image_processor, tokenizer, chat_template=chat_template)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
|
||||||
|
images: ImageInput = None,
|
||||||
|
padding: Union[bool, str, PaddingStrategy] = False,
|
||||||
|
truncation: Union[bool, str, TruncationStrategy] = None,
|
||||||
|
max_length=None,
|
||||||
|
return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
|
||||||
|
) -> BatchMixFeature:
|
||||||
|
"""
|
||||||
|
Main method to prepare for the model one or several sequence(s) and image(s). This method forwards the `text`
|
||||||
|
and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
|
||||||
|
the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
|
||||||
|
CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
|
||||||
|
of the above two methods for more information.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text (`str`, `List[str]`, `List[List[str]]`):
|
||||||
|
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
||||||
|
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
||||||
|
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
||||||
|
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
||||||
|
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
||||||
|
tensor. Both channels-first and channels-last formats are supported.
|
||||||
|
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
|
||||||
|
Select a strategy to pad the returned sequences (according to the model's padding side and padding
|
||||||
|
index) among:
|
||||||
|
- `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
|
||||||
|
sequence is provided).
|
||||||
|
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
|
||||||
|
acceptable input length for the model if that argument is not provided.
|
||||||
|
- `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
|
||||||
|
lengths).
|
||||||
|
max_length (`int`, *optional*):
|
||||||
|
Maximum length of the returned list and optionally padding length (see above).
|
||||||
|
truncation (`bool`, *optional*):
|
||||||
|
Activates truncation to cut input sequences longer than `max_length` to `max_length`.
|
||||||
|
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
||||||
|
If set, will return tensors of a particular framework. Acceptable values are:
|
||||||
|
|
||||||
|
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
||||||
|
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
||||||
|
- `'np'`: Return NumPy `np.ndarray` objects.
|
||||||
|
- `'jax'`: Return JAX `jnp.ndarray` objects.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[`BatchMixFeature`]: A [`BatchMixFeature`] with the following fields:
|
||||||
|
|
||||||
|
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
|
||||||
|
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
|
||||||
|
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
|
||||||
|
`None`).
|
||||||
|
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
|
||||||
|
"""
|
||||||
|
if images is not None:
|
||||||
|
if is_image_or_image_url(images):
|
||||||
|
images = [[images]]
|
||||||
|
elif isinstance(images, list) and is_image_or_image_url(images[0]):
|
||||||
|
images = [images]
|
||||||
|
elif (
|
||||||
|
not isinstance(images, list)
or not isinstance(images[0], list)
or not is_image_or_image_url(images[0][0])
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"Invalid input images. Please provide a single image or a list of images or a list of list of images."
|
||||||
|
)
|
||||||
|
images = [[load_image(im) for im in sample] for sample in images]
|
||||||
|
image_inputs = self.image_processor(images, patch_size=self.patch_size, return_tensors=return_tensors)
|
||||||
|
else:
|
||||||
|
image_inputs = {}
|
||||||
|
|
||||||
|
if isinstance(text, str):
|
||||||
|
text = [text]
|
||||||
|
elif not isinstance(text, list) or not isinstance(text[0], str):
|
||||||
|
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
|
||||||
|
|
||||||
|
# try to expand inputs in processing if we have the necessary parts
|
||||||
|
prompt_strings = text
|
||||||
|
if image_inputs.get("pixel_values") is not None:
|
||||||
|
# Replace the image token with the expanded image token sequence
|
||||||
|
images = image_inputs["pixel_values"]
|
||||||
|
image_sizes = image_inputs.pop("image_sizes")
|
||||||
|
prompt_strings = []
|
||||||
|
|
||||||
|
for sample_images, sample_image_sizes, sample in zip(images, image_sizes, text):
|
||||||
|
replace_strings = []
|
||||||
|
# First calculate the number of tokens needed for each image and put in a placeholder
|
||||||
|
for image, image_size in zip(sample_images, sample_image_sizes):
|
||||||
|
height, width = image_size
|
||||||
|
num_height_tokens = height // self.patch_size
|
||||||
|
num_width_tokens = width // self.patch_size
|
||||||
|
replace_tokens = [
|
||||||
|
[self.image_token] * num_width_tokens + [self.image_break_token]
|
||||||
|
] * num_height_tokens
|
||||||
|
# Flatten list
|
||||||
|
replace_tokens = [item for sublist in replace_tokens for item in sublist]
|
||||||
|
replace_tokens[-1] = self.image_end_token
|
||||||
|
replace_str = "".join(replace_tokens)
|
||||||
|
replace_strings.append(replace_str)
|
||||||
|
sample = sample.replace(self.image_token, "<placeholder>", 1)
|
||||||
|
|
||||||
|
while "<placeholder>" in sample:
|
||||||
|
replace_str = replace_strings.pop(0)
|
||||||
|
sample = sample.replace("<placeholder>", replace_str, 1)
|
||||||
|
|
||||||
|
prompt_strings.append(sample)
|
||||||
|
|
||||||
|
text_inputs = self.tokenizer(
|
||||||
|
prompt_strings,
|
||||||
|
return_tensors=return_tensors,
|
||||||
|
padding=padding,
|
||||||
|
truncation=truncation,
|
||||||
|
max_length=max_length,
|
||||||
|
)
|
||||||
|
return BatchMixFeature(data={**text_inputs, **image_inputs})
|
||||||
|
|
||||||
|
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
|
||||||
|
def batch_decode(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
||||||
|
refer to the docstring of this method for more information.
|
||||||
|
"""
|
||||||
|
return self.tokenizer.batch_decode(*args, **kwargs)
|
||||||
|
|
||||||
|
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
|
||||||
|
def decode(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
|
||||||
|
the docstring of this method for more information.
|
||||||
|
"""
|
||||||
|
return self.tokenizer.decode(*args, **kwargs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
|
||||||
|
def model_input_names(self):
|
||||||
|
tokenizer_input_names = self.tokenizer.model_input_names
|
||||||
|
image_processor_input_names = self.image_processor.model_input_names
|
||||||
|
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
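The image-token expansion performed in `__call__` above can be sketched in isolation. This is a minimal illustration rather than the library's implementation: the helper names `expand_image_token` and `expand_prompt` are hypothetical, and the prompt is rebuilt with `str.split` instead of the `<placeholder>` substitution the processor uses. The token strings and the default patch size of 16 are taken from the `__init__` defaults shown above.

```python
from typing import List, Sequence, Tuple


def expand_image_token(
    image_size: Tuple[int, int],
    patch_size: int = 16,
    image_token: str = "[IMG]",
    image_break_token: str = "[IMG_BREAK]",
    image_end_token: str = "[IMG_END]",
) -> str:
    """Build the token string for one image (hypothetical helper)."""
    height, width = image_size
    num_height_tokens = height // patch_size
    num_width_tokens = width // patch_size

    tokens: List[str] = []
    for _ in range(num_height_tokens):
        # One row of patch tokens, closed by a break token.
        tokens.extend([image_token] * num_width_tokens)
        tokens.append(image_break_token)
    if tokens:
        # The break token of the last row becomes the end token.
        tokens[-1] = image_end_token
    return "".join(tokens)


def expand_prompt(
    prompt: str,
    image_sizes: Sequence[Tuple[int, int]],
    patch_size: int = 16,
    image_token: str = "[IMG]",
) -> str:
    """Replace each image token in `prompt`, in order, by its expansion."""
    parts = prompt.split(image_token)
    if len(parts) - 1 != len(image_sizes):
        raise ValueError("Number of image tokens and images do not match.")
    expanded = parts[0]
    for size, tail in zip(image_sizes, parts[1:]):
        expanded += expand_image_token(size, patch_size, image_token) + tail
    return expanded
```

Splitting on the image token sidesteps the problem the placeholder loop solves in the processor: the expanded string itself contains `[IMG]`, so a naive in-place `replace` would match its own output.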
|
||||||
@@ -7067,6 +7067,20 @@ class Pix2StructVisionModel(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class PixtralModel(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class PixtralPreTrainedModel(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class PLBartForCausalLM(metaclass=DummyObject):
|
||||||
_backends = ["torch"]
|
||||||
|
|
||||||
|
|||||||
@@ -506,6 +506,13 @@ class Pix2StructImageProcessor(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["vision"])
|
||||||
|
|
||||||
|
|
||||||
|
class PixtralImageProcessor(metaclass=DummyObject):
|
||||||
|
_backends = ["vision"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["vision"])
|
||||||
|
|
||||||
|
|
||||||
class PoolFormerFeatureExtractor(metaclass=DummyObject):
|
||||||
_backends = ["vision"]
|
||||||
|
|
||||||
|
|||||||
@@ -569,3 +569,50 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
# check that both inputs are handled correctly and generate the same output
|
||||||
self.assertListEqual(output_expanded[:, -20:].tolist(), output[:, -20:].tolist())
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_bitsandbytes
|
||||||
|
def test_pixtral(self):
|
||||||
|
model_id = "hf-internal-testing/pixtral-12b"
|
||||||
|
model = LlavaForConditionalGeneration.from_pretrained(model_id)
|
||||||
|
processor = AutoProcessor.from_pretrained(model_id)
|
||||||
|
|
||||||
|
IMG_URLS = [
|
||||||
|
Image.open(requests.get("https://picsum.photos/id/237/400/300", stream=True).raw),
|
||||||
|
Image.open(requests.get("https://picsum.photos/id/231/200/300", stream=True).raw),
|
||||||
|
Image.open(requests.get("https://picsum.photos/id/27/500/500", stream=True).raw),
|
||||||
|
Image.open(requests.get("https://picsum.photos/id/17/150/600", stream=True).raw),
|
||||||
|
]
|
||||||
|
PROMPT = "<s>[INST]Describe the images.\n[IMG][IMG][IMG][IMG][/INST]"
|
||||||
|
|
||||||
|
# image = Image.open(requests.get(url, stream=True).raw)
|
||||||
|
inputs = processor(text=PROMPT, images=IMG_URLS, return_tensors="pt").to("cuda")
|
||||||
|
generate_ids = model.generate(**inputs, max_new_tokens=500)
|
||||||
|
ouptut = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
EXPECTED_GENERATION = """
|
||||||
|
Describe the images.
|
||||||
|
Sure, let's break down each image description:
|
||||||
|
|
||||||
|
1. **Image 1:**
|
||||||
|
- **Description:** A black dog with a glossy coat is sitting on a wooden floor. The dog has a focused expression and is looking directly at the camera.
|
||||||
|
- **Details:** The wooden floor has a rustic appearance with visible wood grain patterns. The dog's eyes are a striking color, possibly brown or amber, which contrasts with its black fur.
|
||||||
|
|
||||||
|
2. **Image 2:**
|
||||||
|
- **Description:** A scenic view of a mountainous landscape with a winding road cutting through it. The road is surrounded by lush green vegetation and leads to a distant valley.
|
||||||
|
- **Details:** The mountains are rugged with steep slopes, and the sky is clear, indicating good weather. The winding road adds a sense of depth and perspective to the image.
|
||||||
|
|
||||||
|
3. **Image 3:**
|
||||||
|
- **Description:** A beach scene with waves crashing against the shore. There are several people in the water and on the beach, enjoying the waves and the sunset.
|
||||||
|
- **Details:** The waves are powerful, creating a dynamic and lively atmosphere. The sky is painted with hues of orange and pink from the setting sun, adding a warm glow to the scene.
|
||||||
|
|
||||||
|
4. **Image 4:**
|
||||||
|
- **Description:** A garden path leading to a large tree with a bench underneath it. The path is bordered by well-maintained grass and flowers.
|
||||||
|
- **Details:** The path is made of small stones or gravel, and the tree provides a shaded area with the bench invitingly placed beneath it. The surrounding area is lush and green, suggesting a well-kept garden.
|
||||||
|
|
||||||
|
Each image captures a different scene, from a close-up of a dog to expansive natural landscapes, showcasing various elements of nature and human interaction with it.
|
||||||
|
"""
|
||||||
|
# fmt: on
|
||||||
|
# check that the decoded generation matches the expected text
|
||||||
|
self.assertEqual(ouptut, EXPECTED_GENERATION)
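The test above relies on the prompt containing exactly one `[IMG]` placeholder per image. A small sanity check along these lines could catch a mismatch before the processor is called. It is only a sketch: `check_image_placeholders` is a hypothetical helper, not part of the library, and it assumes the default `[IMG]` token.

```python
def check_image_placeholders(prompt: str, num_images: int, image_token: str = "[IMG]") -> int:
    """Return the placeholder count, raising if it differs from `num_images`."""
    found = prompt.count(image_token)
    if found != num_images:
        raise ValueError(
            f"Prompt contains {found} '{image_token}' placeholder(s) but {num_images} image(s) were given."
        )
    return found


# The prompt used in the test, with four images.
prompt = "<s>[INST]Describe the images.\n[IMG][IMG][IMG][IMG][/INST]"
num_placeholders = check_image_placeholders(prompt, num_images=4)
```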
|
||||||
|
|||||||
0
tests/models/pixtral/__init__.py
Normal file
217
tests/models/pixtral/test_image_processing_pixtral.py
Normal file
@@ -0,0 +1,217 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2024 HuggingFace Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import random
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from transformers.testing_utils import require_torch, require_vision
|
||||||
|
from transformers.utils import is_torch_available, is_vision_available
|
||||||
|
|
||||||
|
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
if is_vision_available():
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from transformers import PixtralImageProcessor
|
||||||
|
|
||||||
|
|
||||||
|
class PixtralImageProcessingTester(unittest.TestCase):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
parent,
|
||||||
|
batch_size=7,
|
||||||
|
num_channels=3,
|
||||||
|
image_size=18,
|
||||||
|
max_num_images_per_sample=3,
|
||||||
|
min_resolution=30,
|
||||||
|
max_resolution=400,
|
||||||
|
do_resize=True,
|
||||||
|
size=None,
|
||||||
|
patch_size=None,
|
||||||
|
do_normalize=True,
|
||||||
|
image_mean=[0.48145466, 0.4578275, 0.40821073],
|
||||||
|
image_std=[0.26862954, 0.26130258, 0.27577711],
|
||||||
|
do_convert_rgb=True,
|
||||||
|
):
|
||||||
|
size = size if size is not None else {"longest_edge": 24}
|
||||||
|
patch_size = patch_size if patch_size is not None else {"height": 8, "width": 8}
|
||||||
|
self.parent = parent
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.num_channels = num_channels
|
||||||
|
self.image_size = image_size
|
||||||
|
self.max_num_images_per_sample = max_num_images_per_sample
|
||||||
|
self.min_resolution = min_resolution
|
||||||
|
self.max_resolution = max_resolution
|
||||||
|
self.do_resize = do_resize
|
||||||
|
self.size = size
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.do_normalize = do_normalize
|
||||||
|
self.image_mean = image_mean
|
||||||
|
self.image_std = image_std
|
||||||
|
self.do_convert_rgb = do_convert_rgb
|
||||||
|
|
||||||
|
def prepare_image_processor_dict(self):
|
||||||
|
return {
|
||||||
|
"do_resize": self.do_resize,
|
||||||
|
"size": self.size,
|
||||||
|
"patch_size": self.patch_size,
|
||||||
|
"do_normalize": self.do_normalize,
|
||||||
|
"image_mean": self.image_mean,
|
||||||
|
"image_std": self.image_std,
|
||||||
|
"do_convert_rgb": self.do_convert_rgb,
|
||||||
|
}
|
||||||
|
|
||||||
|
def expected_output_image_shape(self, image):
|
||||||
|
if isinstance(image, Image.Image):
|
||||||
|
width, height = image.size
|
||||||
|
elif isinstance(image, np.ndarray):
|
||||||
|
height, width = image.shape[:2]
|
||||||
|
elif isinstance(image, torch.Tensor):
|
||||||
|
height, width = image.shape[-2:]
|
||||||
|
|
||||||
|
max_height = max_width = self.size.get("longest_edge")
|
||||||
|
|
||||||
|
ratio = max(height / max_height, width / max_width)
|
||||||
|
if ratio > 1:
|
||||||
|
height = int(np.ceil(height / ratio))
|
||||||
|
width = int(np.ceil(width / ratio))
|
||||||
|
|
||||||
|
patch_height, patch_width = self.patch_size["height"], self.patch_size["width"]
|
||||||
|
num_height_tokens = (height - 1) // patch_height + 1
|
||||||
|
num_width_tokens = (width - 1) // patch_width + 1
|
||||||
|
|
||||||
|
height = num_height_tokens * patch_height
|
||||||
|
width = num_width_tokens * patch_width
|
||||||
|
|
||||||
|
return self.num_channels, height, width
|
||||||
|
|
||||||
|
def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
|
||||||
|
# Use prepare_image_inputs to make a list of list of single images
|
||||||
|
|
||||||
|
images_list = []
|
||||||
|
for _ in range(self.batch_size):
|
||||||
|
images = []
|
||||||
|
for _ in range(random.randint(1, self.max_num_images_per_sample)):
|
||||||
|
img = prepare_image_inputs(
|
||||||
|
batch_size=1,
|
||||||
|
num_channels=self.num_channels,
|
||||||
|
min_resolution=self.min_resolution,
|
||||||
|
max_resolution=self.max_resolution,
|
||||||
|
equal_resolution=equal_resolution,
|
||||||
|
numpify=numpify,
|
||||||
|
torchify=torchify,
|
||||||
|
)[0]
|
||||||
|
images.append(img)
|
||||||
|
images_list.append(images)
|
||||||
|
return images_list
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
@require_vision
|
||||||
|
class PixtralImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||||
|
image_processing_class = PixtralImageProcessor if is_vision_available() else None
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
|
self.image_processor_tester = PixtralImageProcessingTester(self)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def image_processor_dict(self):
|
||||||
|
return self.image_processor_tester.prepare_image_processor_dict()
|
||||||
|
|
||||||
|
def test_image_processor_properties(self):
|
||||||
|
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||||
|
self.assertTrue(hasattr(image_processing, "do_resize"))
|
||||||
|
self.assertTrue(hasattr(image_processing, "size"))
|
||||||
|
self.assertTrue(hasattr(image_processing, "patch_size"))
|
||||||
|
self.assertTrue(hasattr(image_processing, "do_rescale"))
|
||||||
|
self.assertTrue(hasattr(image_processing, "rescale_factor"))
|
||||||
|
self.assertTrue(hasattr(image_processing, "do_normalize"))
|
||||||
|
self.assertTrue(hasattr(image_processing, "image_mean"))
|
||||||
|
self.assertTrue(hasattr(image_processing, "image_std"))
|
||||||
|
self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
|
||||||
|
|
||||||
|
def test_call_pil(self):
|
||||||
|
# Initialize image_processing
|
||||||
|
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||||
|
# create random PIL images
|
||||||
|
image_inputs_list = self.image_processor_tester.prepare_image_inputs()
|
||||||
|
for image_inputs in image_inputs_list:
|
||||||
|
for image in image_inputs:
|
||||||
|
self.assertIsInstance(image, Image.Image)
|
||||||
|
|
||||||
|
# Test not batched input
|
||||||
|
encoded_images = image_processing(image_inputs_list[0][0], return_tensors="pt").pixel_values
|
||||||
|
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs_list[0][0])
|
||||||
|
self.assertEqual(tuple(encoded_images[0][0].shape), expected_output_image_shape)
|
||||||
|
|
||||||
|
# Test batched
|
||||||
|
batch_encoded_images = image_processing(image_inputs_list, return_tensors="pt").pixel_values
|
||||||
|
for encoded_images, images in zip(batch_encoded_images, image_inputs_list):
|
||||||
|
for encoded_image, image in zip(encoded_images, images):
|
||||||
|
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image)
|
||||||
|
self.assertEqual(tuple(encoded_image.shape), expected_output_image_shape)
|
||||||
|
|
||||||
|
def test_call_numpy(self):
|
||||||
|
# Initialize image_processing
|
||||||
|
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||||
|
# create random numpy tensors
|
||||||
|
image_inputs_list = self.image_processor_tester.prepare_image_inputs(numpify=True)
|
||||||
|
for image_inputs in image_inputs_list:
|
||||||
|
for image in image_inputs:
|
||||||
|
self.assertIsInstance(image, np.ndarray)
|
||||||
|
|
||||||
|
# Test not batched input
|
||||||
|
encoded_images = image_processing(image_inputs_list[0][0], return_tensors="pt").pixel_values
|
||||||
|
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs_list[0][0])
|
||||||
|
self.assertEqual(tuple(encoded_images[0][0].shape), expected_output_image_shape)
|
||||||
|
|
||||||
|
# Test batched
|
||||||
|
batch_encoded_images = image_processing(image_inputs_list, return_tensors="pt").pixel_values
|
||||||
|
for encoded_images, images in zip(batch_encoded_images, image_inputs_list):
|
||||||
|
for encoded_image, image in zip(encoded_images, images):
|
||||||
|
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image)
|
||||||
|
self.assertEqual(tuple(encoded_image.shape), expected_output_image_shape)
|
||||||
|
|
||||||
|
def test_call_pytorch(self):
|
||||||
|
# Initialize image_processing
|
||||||
|
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||||
|
# create random PyTorch tensors
|
||||||
|
image_inputs_list = self.image_processor_tester.prepare_image_inputs(torchify=True)
|
||||||
|
for image_inputs in image_inputs_list:
|
||||||
|
for image in image_inputs:
|
||||||
|
self.assertIsInstance(image, torch.Tensor)
|
||||||
|
|
||||||
|
# Test not batched input
|
||||||
|
encoded_images = image_processing(image_inputs_list[0][0], return_tensors="pt").pixel_values
|
||||||
|
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs_list[0][0])
|
||||||
|
self.assertEqual(tuple(encoded_images[0][0].shape), expected_output_image_shape)
|
||||||
|
|
||||||
|
# Test batched
|
||||||
|
batch_encoded_images = image_processing(image_inputs_list, return_tensors="pt").pixel_values
|
||||||
|
for encoded_images, images in zip(batch_encoded_images, image_inputs_list):
|
||||||
|
for encoded_image, image in zip(encoded_images, images):
|
||||||
|
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image)
|
||||||
|
self.assertEqual(tuple(encoded_image.shape), expected_output_image_shape)
|
||||||
|
|
||||||
|
@unittest.skip(reason="PixtralImageProcessor doesn't treat 4 channel PIL and numpy consistently yet") # FIXME Amy
|
||||||
|
def test_call_numpy_4_channels(self):
|
||||||
|
pass
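The shape arithmetic in `expected_output_image_shape` above can be checked by hand with a standalone version. The sketch below mirrors the tester's logic under its default settings (`longest_edge=24`, 8x8 patches, 3 channels). The function name `expected_pixtral_shape` is hypothetical, and a single integer `patch_size` is assumed instead of the tester's height/width dict.

```python
import math
from typing import Tuple


def expected_pixtral_shape(
    height: int,
    width: int,
    longest_edge: int = 24,
    patch_size: int = 8,
    num_channels: int = 3,
) -> Tuple[int, int, int]:
    """Mirror the tester: shrink to fit `longest_edge`, then round up to whole patches."""
    ratio = max(height / longest_edge, width / longest_edge)
    if ratio > 1:
        # Only downscale; images that already fit are left at their size.
        height = int(math.ceil(height / ratio))
        width = int(math.ceil(width / ratio))

    # Ceiling division: number of patches needed to cover each side.
    num_height_tokens = (height - 1) // patch_size + 1
    num_width_tokens = (width - 1) // patch_size + 1

    return num_channels, num_height_tokens * patch_size, num_width_tokens * patch_size


# A 48x96 image is scaled by 1/4 to 12x24, then padded up to 16x24.
shape = expected_pixtral_shape(48, 96)
```

An image that already fits within `longest_edge`, such as 10x10, is not resized and is only rounded up to the patch grid, giving 16x16.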
|
||||||
292
tests/models/pixtral/test_modeling_pixtral.py
Normal file
@@ -0,0 +1,292 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Testing suite for the PyTorch Pixtral model."""
|
||||||
|
|
||||||
|
import gc
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from transformers import (
|
||||||
|
AutoProcessor,
|
||||||
|
PixtralModel,
|
||||||
|
PixtralVisionConfig,
|
||||||
|
is_torch_available,
|
||||||
|
is_vision_available,
|
||||||
|
)
|
||||||
|
from transformers.testing_utils import (
|
||||||
|
require_bitsandbytes,
|
||||||
|
require_torch,
|
||||||
|
slow,
|
||||||
|
torch_device,
|
||||||
|
)
|
||||||
|
|
||||||
|
from ...test_configuration_common import ConfigTester
|
||||||
|
from ...test_modeling_common import ModelTesterMixin, floats_tensor
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
else:
|
||||||
|
is_torch_greater_or_equal_than_2_0 = False
|
||||||
|
|
||||||
|
if is_vision_available():
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
class PixtralModelTester:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
parent,
|
||||||
|
batch_size=12,
|
||||||
|
image_size=30,
|
||||||
|
patch_size=2,
|
||||||
|
num_channels=3,
|
||||||
|
is_training=True,
|
||||||
|
hidden_size=32,
|
||||||
|
projection_dim=32,
|
||||||
|
num_hidden_layers=2,
|
||||||
|
num_attention_heads=4,
|
||||||
|
intermediate_size=37,
|
||||||
|
dropout=0.1,
|
||||||
|
attention_dropout=0.1,
|
||||||
|
initializer_range=0.02,
|
||||||
|
scope=None,
|
||||||
|
):
|
||||||
|
self.parent = parent
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.image_size = image_size
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.num_channels = num_channels
|
||||||
|
self.is_training = is_training
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.projection_dim = projection_dim
|
||||||
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.intermediate_size = intermediate_size
|
||||||
|
self.dropout = dropout
|
||||||
|
self.attention_dropout = attention_dropout
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
self.scope = scope
|
||||||
|
|
||||||
|
# in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
|
||||||
|
num_patches = (image_size // patch_size) ** 2
|
||||||
|
self.seq_length = num_patches + 1
|
||||||
|
|
||||||
|
def prepare_config_and_inputs(self):
|
||||||
|
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||||
|
config = self.get_config()
|
||||||
|
|
||||||
|
return config, pixel_values
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
return PixtralVisionConfig(
|
||||||
|
image_size=self.image_size,
|
||||||
|
patch_size=self.patch_size,
|
||||||
|
num_channels=self.num_channels,
|
||||||
|
hidden_size=self.hidden_size,
|
||||||
|
projection_dim=self.projection_dim,
|
||||||
|
num_hidden_layers=self.num_hidden_layers,
|
||||||
|
num_attention_heads=self.num_attention_heads,
|
||||||
|
intermediate_size=self.intermediate_size,
|
||||||
|
dropout=self.dropout,
|
||||||
|
attention_dropout=self.attention_dropout,
|
||||||
|
initializer_range=self.initializer_range,
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_and_check_model(self, config, pixel_values):
|
||||||
|
model = PixtralModel(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
result = model(pixel_values)
|
||||||
|
# expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
|
||||||
|
image_size = (self.image_size, self.image_size)
|
||||||
|
patch_size = (self.patch_size, self.patch_size)
|
||||||
|
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||||
|
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
|
||||||
|
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
|
||||||
|
|
||||||
|
def create_and_check_model_with_projection(self, config, pixel_values):
|
||||||
|
model = PixtralModel(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
result = model(pixel_values)
|
||||||
|
# expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
|
||||||
|
image_size = (self.image_size, self.image_size)
|
||||||
|
patch_size = (self.patch_size, self.patch_size)
|
||||||
|
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||||
|
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
|
||||||
|
self.parent.assertEqual(result.image_embeds.shape, (self.batch_size, self.projection_dim))
|
||||||
|
|
||||||
|
def prepare_config_and_inputs_for_common(self):
|
||||||
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
|
config, pixel_values = config_and_inputs
|
||||||
|
inputs_dict = {"pixel_values": pixel_values}
|
||||||
|
return config, inputs_dict
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
class PixtralModelModelTest(ModelTesterMixin, unittest.TestCase):
|
||||||
|
"""
|
||||||
|
Model tester for `PixtralModel`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
all_model_classes = (PixtralModel,) if is_torch_available() else ()
|
||||||
|
test_pruning = False
|
||||||
|
test_head_masking = False
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.model_tester = PixtralModelTester(self)
|
||||||
|
self.config_tester = ConfigTester(self, config_class=PixtralVisionConfig, has_text_modality=False)
|
||||||
|
|
||||||
|
@unittest.skip("model does not support input embeds")
|
||||||
|
def test_inputs_embeds(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("model does not support input embeds")
|
||||||
|
def test_inputs_embeds_matches_input_ids(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(
|
||||||
|
reason="This architecture does not seem to compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
|
||||||
|
)
|
||||||
|
def test_training_gradient_checkpointing(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(
|
||||||
|
reason="This architecture does not seem to compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
|
||||||
|
)
|
||||||
|
def test_training_gradient_checkpointing_use_reentrant(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(
|
||||||
|
reason="This architecture does not seem to compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
|
||||||
|
)
|
||||||
|
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="Compile not yet supported in Pixtral models")
|
||||||
|
def test_sdpa_can_compile_dynamic(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="Compile not yet supported in Pixtral models")
|
||||||
|
def test_sdpa_can_dispatch_on_flash(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
    @unittest.skip(reason="Not supported yet")
    def test_attention_outputs(self):
        pass

    @unittest.skip(reason="Not supported yet")
    def test_cpu_offload(self):
        pass

    @unittest.skip(reason="Not supported yet")
    def test_batching_equivalence(self):
        pass

    @unittest.skip(reason="Not supported yet")
    def test_disk_offload_bin(self):
        pass

    @unittest.skip(reason="Not supported yet")
    def test_retain_grad_hidden_states_attentions(self):
        pass

    @unittest.skip(reason="Not supported yet")
    def test_multi_gpu_data_parallel_forward(self):
        pass

    @unittest.skip(reason="Not supported yet")
    def test_model_parallelism(self):
        pass

    @unittest.skip(reason="Not supported yet")
    def test_model_outputs_equivalence(self):
        pass

    @unittest.skip(reason="Not supported yet")
    def test_save_load(self):
        pass

    @unittest.skip(reason="Not supported yet")
    def test_model_get_set_embeddings(self):
        pass

    @unittest.skip(reason="Not supported yet")
    def test_resize_tokens_embeddings(self):
        pass

    @unittest.skip(reason="Not supported yet")
    def test_model_main_input_name(self):
        pass

    @unittest.skip(reason="Not supported yet")
    def test_initialization(self):
        pass

    @unittest.skip(reason="Not supported yet")
    def test_hidden_states_output(self):
        pass

    @unittest.skip(reason="Not supported yet")
    def test_gradient_checkpointing_backward_compatibility(self):
        pass

    @unittest.skip(reason="Not supported yet")
    def test_feed_forward_chunking(self):
        pass

    @unittest.skip(reason="Not supported yet")
    def test_disk_offload_safetensors(self):
        pass

    @unittest.skip(reason="Not supported yet")
    def test_determinism(self):
        pass


@require_torch
class PixtralModelIntegrationTest(unittest.TestCase):
    def setUp(self):
        self.processor = AutoProcessor.from_pretrained("hf-internal-testing/pixtral-12b")

    def tearDown(self):
        gc.collect()
        torch.cuda.empty_cache()

    @slow
    @require_bitsandbytes
    def test_small_model_integration_test(self):
        # Let's make sure we test the preprocessing to replace what is used
        model = PixtralModel.from_pretrained("hf-internal-testing/pixtral-12b", load_in_4bit=True)

        prompt = "<s>[INST][IMG]\nWhat are the things I should be cautious about when I visit this place?[/INST]"
        image_file = "https://pixtral-vl.github.io/static/images/view.jpg"
        raw_image = Image.open(requests.get(image_file, stream=True).raw)
        inputs = self.processor(prompt, raw_image, return_tensors="pt")

        EXPECTED_INPUT_IDS = torch.tensor([[1, 32000, 28705, 13, 11123, 28747, 1824, 460, 272, 1722, 315, 1023, 347, 13831, 925, 684, 739, 315, 3251, 456, 1633, 28804, 13, 4816, 8048, 12738, 28747]])  # fmt: skip
        self.assertTrue(torch.equal(inputs["input_ids"], EXPECTED_INPUT_IDS))

        output = model.generate(**inputs, max_new_tokens=20)
        EXPECTED_DECODED_TEXT = "\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT: When visiting this place, there are a few things one should be cautious about. Firstly,"  # fmt: skip

        self.assertEqual(
            self.processor.decode(output[0], skip_special_tokens=True),
            EXPECTED_DECODED_TEXT,
        )
233
tests/models/pixtral/test_processor_pixtral.py
Normal file
@@ -0,0 +1,233 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import requests
import torch

from transformers.testing_utils import require_vision
from transformers.utils import is_vision_available


if is_vision_available():
    from PIL import Image

    from transformers import AutoTokenizer, PixtralImageProcessor, PixtralProcessor


@require_vision
class PixtralProcessorTest(unittest.TestCase):
    processor_class = PixtralProcessor

    @classmethod
    def setUpClass(cls):
        cls.url_0 = "https://www.ilankelman.org/stopsigns/australia.jpg"
        cls.image_0 = Image.open(requests.get(cls.url_0, stream=True).raw)
        cls.url_1 = "http://images.cocodataset.org/val2017/000000039769.jpg"
        cls.image_1 = Image.open(requests.get(cls.url_1, stream=True).raw)
        cls.url_2 = "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg"
        cls.image_2 = Image.open(requests.get(cls.url_2, stream=True).raw)

    def setUp(self):
        super().setUp()

        # FIXME - just load the processor directly from the checkpoint
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/pixtral-12b")
        image_processor = PixtralImageProcessor()
        self.processor = PixtralProcessor(tokenizer=tokenizer, image_processor=image_processor)

    @unittest.skip("No chat template was set for this model (yet)")
    def test_chat_template(self):
        expected_prompt = "USER: [IMG]\nWhat is shown in this image? ASSISTANT:"

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": "What is shown in this image?"},
                ],
            },
        ]
        formatted_prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True)
        self.assertEqual(expected_prompt, formatted_prompt)

    @unittest.skip("No chat template was set for this model (yet)")
    def test_image_token_filling(self):
        # Important to check with non square image
        image = torch.randint(0, 2, (3, 500, 316))
        expected_image_tokens = 1526
        image_token_index = 32000

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": "What is shown in this image?"},
                ],
            },
        ]
        inputs = self.processor(
            text=[self.processor.apply_chat_template(messages)],
            images=[image],
            return_tensors="pt",
        )
        image_tokens = (inputs["input_ids"] == image_token_index).sum().item()
        self.assertEqual(expected_image_tokens, image_tokens)

    def test_processor_with_single_image(self):
        prompt_string = "USER: [IMG]\nWhat's the content of the image? ASSISTANT:"

        # Make small for checking image token expansion
        self.processor.image_processor.size = {"longest_edge": 30}
        self.processor.image_processor.patch_size = {"height": 2, "width": 2}

        # Test passing in an image
        inputs_image = self.processor(text=prompt_string, images=self.image_0, return_tensors="pt")
        self.assertIn("input_ids", inputs_image)
        self.assertTrue(len(inputs_image["input_ids"]) == 1)
        self.assertIsInstance(inputs_image["input_ids"], torch.Tensor)
        self.assertIsInstance(inputs_image["pixel_values"], list)
        self.assertTrue(len(inputs_image["pixel_values"]) == 1)
        self.assertIsInstance(inputs_image["pixel_values"][0], list)
        self.assertTrue(len(inputs_image["pixel_values"][0]) == 1)
        self.assertIsInstance(inputs_image["pixel_values"][0][0], torch.Tensor)

        # fmt: off
        input_ids = inputs_image["input_ids"]
        self.assertEqual(
            input_ids[0].tolist(),
            # Equivalent to "USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the content of the image? ASSISTANT:"
            [21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 4701, 1307, 1278, 3937, 1063, 1349, 4290, 16002, 41150, 1058]
        )
        # fmt: on

        # Test passing in a url
        inputs_url = self.processor(text=prompt_string, images=self.url_0, return_tensors="pt")
        self.assertIn("input_ids", inputs_url)
        self.assertTrue(len(inputs_url["input_ids"]) == 1)
        self.assertIsInstance(inputs_url["input_ids"], torch.Tensor)
        self.assertIsInstance(inputs_url["pixel_values"], list)
        self.assertTrue(len(inputs_url["pixel_values"]) == 1)
        self.assertIsInstance(inputs_url["pixel_values"][0], list)
        self.assertTrue(len(inputs_url["pixel_values"][0]) == 1)
        self.assertIsInstance(inputs_url["pixel_values"][0][0], torch.Tensor)

        # fmt: off
        input_ids = inputs_url["input_ids"]
        self.assertEqual(
            input_ids[0].tolist(),
            # Equivalent to "USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the content of the image? ASSISTANT:"
            [21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 4701, 1307, 1278, 3937, 1063, 1349, 4290, 16002, 41150, 1058]
        )
        # fmt: on

    def test_processor_with_multiple_images_single_list(self):
        prompt_string = "USER: [IMG][IMG]\nWhat's the difference between these two images? ASSISTANT:"

        # Make small for checking image token expansion
        self.processor.image_processor.size = {"longest_edge": 30}
        self.processor.image_processor.patch_size = {"height": 2, "width": 2}

        # Test passing in an image
        inputs_image = self.processor(text=prompt_string, images=[self.image_0, self.image_1], return_tensors="pt")
        self.assertIn("input_ids", inputs_image)
        self.assertTrue(len(inputs_image["input_ids"]) == 1)
        self.assertIsInstance(inputs_image["input_ids"], torch.Tensor)
        self.assertIsInstance(inputs_image["pixel_values"], list)
        self.assertTrue(len(inputs_image["pixel_values"]) == 1)
        self.assertIsInstance(inputs_image["pixel_values"][0], list)
        self.assertTrue(len(inputs_image["pixel_values"][0]) == 2)
        self.assertIsInstance(inputs_image["pixel_values"][0][0], torch.Tensor)

        # fmt: off
        input_ids = inputs_image["input_ids"]
        self.assertEqual(
            input_ids[0].tolist(),
            # Equivalent to ["USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END][IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the difference between these two images? ASSISTANT:"]
            [21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058]
        )
        # fmt: on

        # Test passing in a url
        inputs_url = self.processor(text=prompt_string, images=[self.url_0, self.url_1], return_tensors="pt")
        self.assertIn("input_ids", inputs_url)
        self.assertTrue(len(inputs_url["input_ids"]) == 1)
        self.assertIsInstance(inputs_url["input_ids"], torch.Tensor)
        self.assertIsInstance(inputs_url["pixel_values"], list)
        self.assertTrue(len(inputs_url["pixel_values"]) == 1)
        self.assertIsInstance(inputs_url["pixel_values"][0], list)
        self.assertTrue(len(inputs_url["pixel_values"][0]) == 2)
        self.assertIsInstance(inputs_url["pixel_values"][0][0], torch.Tensor)
        # fmt: off
        input_ids = inputs_url["input_ids"]
        self.assertEqual(
            input_ids[0].tolist(),
            # Equivalent to ["USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END][IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the difference between these two images? ASSISTANT:"]
            [21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058]
        )
        # fmt: on

    def test_processor_with_multiple_images_multiple_lists(self):
        prompt_string = [
            "USER: [IMG][IMG]\nWhat's the difference between these two images? ASSISTANT:",
            "USER: [IMG]\nWhat's the content of the image? ASSISTANT:",
        ]
        self.processor.tokenizer.pad_token = "</s>"
        image_inputs = [[self.image_0, self.image_1], [self.image_2]]

        # Make small for checking image token expansion
        self.processor.image_processor.size = {"longest_edge": 30}
        self.processor.image_processor.patch_size = {"height": 2, "width": 2}

        # Test passing in an image
        inputs_image = self.processor(text=prompt_string, images=image_inputs, return_tensors="pt", padding=True)
        self.assertIn("input_ids", inputs_image)
        self.assertTrue(len(inputs_image["input_ids"]) == 2)
        self.assertIsInstance(inputs_image["input_ids"], torch.Tensor)
        self.assertIsInstance(inputs_image["pixel_values"], list)
        self.assertTrue(len(inputs_image["pixel_values"]) == 2)
        self.assertIsInstance(inputs_image["pixel_values"][0], list)
        self.assertTrue(len(inputs_image["pixel_values"][0]) == 2)
        self.assertIsInstance(inputs_image["pixel_values"][0][0], torch.Tensor)

        # fmt: off
        input_ids = inputs_image["input_ids"]
        self.assertEqual(
            input_ids[0].tolist(),
            # Equivalent to ["USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END][IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the difference between these two images? ASSISTANT:"]
            [21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058]
        )
        # fmt: on

        # Test passing in the same image inputs again (this call reuses image_inputs, not URLs)
        inputs_url = self.processor(text=prompt_string, images=image_inputs, return_tensors="pt", padding=True)
        self.assertIn("input_ids", inputs_url)
        self.assertTrue(len(inputs_url["input_ids"]) == 2)
        self.assertIsInstance(inputs_url["input_ids"], torch.Tensor)
        self.assertIsInstance(inputs_url["pixel_values"], list)
        self.assertTrue(len(inputs_url["pixel_values"]) == 2)
        self.assertIsInstance(inputs_url["pixel_values"][0], list)
        self.assertTrue(len(inputs_url["pixel_values"][0]) == 2)
        self.assertIsInstance(inputs_url["pixel_values"][0][0], torch.Tensor)

        # fmt: off
        input_ids = inputs_url["input_ids"]
        self.assertEqual(
            input_ids[0].tolist(),
            # Equivalent to ["USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END][IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the difference between these two images? ASSISTANT:"]
            [21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058]
        )
        # fmt: on
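Note on the expected ids in the processor tests: each `[IMG]` placeholder in the prompt appears to expand into rows of `[IMG]` tokens, with `[IMG_BREAK]` closing every row except the last and `[IMG_END]` closing the image. The sketch below reproduces only that layout. It assumes the token ids 10, 12 and 13 and the 2x2 patch grid that the expected `input_ids` imply; `expand_image_tokens` is a hypothetical helper written for illustration, not part of `PixtralProcessor`.

```python
# Token ids read off the expected input_ids in the tests (an assumption, not an API constant).
IMG, IMG_BREAK, IMG_END = 10, 12, 13


def expand_image_tokens(num_rows: int, num_cols: int) -> list:
    """Hypothetical helper: token layout for one image with a num_rows x num_cols patch grid."""
    if num_rows < 1 or num_cols < 1:
        raise ValueError("the patch grid needs at least one row and one column")
    tokens = []
    for row in range(num_rows):
        # One [IMG] token per patch in this row.
        tokens.extend([IMG] * num_cols)
        # The last row closes the image; every other row ends with a row break.
        tokens.append(IMG_END if row == num_rows - 1 else IMG_BREAK)
    return tokens


single_image = expand_image_tokens(2, 2)
two_images = expand_image_tokens(2, 2) + expand_image_tokens(2, 2)
print(single_image)
print(two_images)
```

The two lists correspond to the image-token runs inside the expected ids of the single-image and two-image tests respectively.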