Add support for Pixtral (#33449)
* initial commit * gloups * updates * work * weights match * nits * nits * updates to support the tokenizer :) * updates * Pixtral processor (#33454) * rough outline * Add in image break and end tokens * Fix * Udo some formatting changes * Set patch_size default * Fix * Fix token expansion * nit in conversion script * Fix image token list creation * done * add expected results * Process list of list of images (#33465) * updates * working image and processor * this is the expected format * some fixes * push current updated * working mult images! * add a small integration test * Uodate configuration docstring * Formatting * Config docstring fix * simplify model test * fixup modeling and etests * Return BatchMixFeature in image processor * fix some copies * update * nits * Update model docstring * Apply suggestions from code review * Fix up * updates * revert modeling changes * update * update * fix load safe * addd liscence * update * use pixel_values as required by the model * skip some tests and refactor * Add pixtral image processing tests (#33476) * Image processing tests * Add processing tests * woops * defaults reflect pixtral image processor * fixup post merge * images -> pixel values * oups sorry Mr docbuilder * isort * fix * fix processor tests * small fixes * nit * update * last nits * oups this was really breaking! * nits * is composition needs to be true --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -862,6 +862,8 @@
|
||||
title: Perceiver
|
||||
- local: model_doc/pix2struct
|
||||
title: Pix2Struct
|
||||
- local: model_doc/pixtral
|
||||
title: Pixtral
|
||||
- local: model_doc/sam
|
||||
title: Segment Anything
|
||||
- local: model_doc/siglip
|
||||
|
||||
@@ -253,6 +253,7 @@ Flax), PyTorch, and/or TensorFlow.
|
||||
| [Phi3](model_doc/phi3) | ✅ | ❌ | ❌ |
|
||||
| [PhoBERT](model_doc/phobert) | ✅ | ✅ | ✅ |
|
||||
| [Pix2Struct](model_doc/pix2struct) | ✅ | ❌ | ❌ |
|
||||
| [Pixtral](model_doc/pixtral) | ❌ | ❌ | ❌ |
|
||||
| [PLBart](model_doc/plbart) | ✅ | ❌ | ❌ |
|
||||
| [PoolFormer](model_doc/poolformer) | ✅ | ❌ | ❌ |
|
||||
| [Pop2Piano](model_doc/pop2piano) | ✅ | ❌ | ❌ |
|
||||
|
||||
98
docs/source/en/model_doc/pixtral.md
Normal file
98
docs/source/en/model_doc/pixtral.md
Normal file
@@ -0,0 +1,98 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Pixtral
|
||||
|
||||
## Overview
|
||||
|
||||
The Pixtral model was released by the Mistral AI team on [Vllm](https://github.com/vllm-project/vllm/pull/8377), where a version of the code can be found!
|
||||
|
||||
|
||||
Tips:
|
||||
|
||||
- Pixtral is a multimodal model, the main contribution is the 2d ROPE on the images, and support for arbitrary image size (the images are not padded together nor are they resized)
|
||||
- This model follows the `Llava` familiy, meaning image embeddings are placed instead of the `[IMG]` token placeholders.
|
||||
- The format for one or mulitple prompts is the following:
|
||||
```
|
||||
"<s>[INST][IMG]\nWhat are the things I should be cautious about when I visit this place?[/INST]"
|
||||
```
|
||||
Then, the processor will replace each `[IMG]` token with a number of `[IMG]` token that depends on the height and the width of the image. Each *row* of the image is separated by a `[IMG_BREAK]` token, and each image is separated by a `[IMG_END]` token.
|
||||
|
||||
This model was contributed by [amyeroberts](https://huggingface.co/amyeroberts) and [ArthurZ](https://huggingface.co/ArthurZ)
|
||||
|
||||
Here is an example of how to run it:
|
||||
|
||||
```python
|
||||
from transformers import LlavaForConditionalGeneration, AutoProcessor
|
||||
from PIL import Image
|
||||
|
||||
model_id = "hf-internal-testing/pixtral-12b"
|
||||
model = LlavaForConditionalGeneration.from_pretrained(model_id).to("cuda")
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
|
||||
IMG_URLS = [
|
||||
"https://picsum.photos/id/237/400/300",
|
||||
"https://picsum.photos/id/231/200/300",
|
||||
"https://picsum.photos/id/27/500/500",
|
||||
"https://picsum.photos/id/17/150/600",
|
||||
]
|
||||
PROMPT = "<s>[INST]Describe the images.\n[IMG][IMG][IMG][IMG][/INST]"
|
||||
|
||||
inputs = processor(text=PROMPT, images=IMG_URLS, return_tensors="pt").to("cuda")
|
||||
generate_ids = model.generate(**inputs, max_new_tokens=500)
|
||||
ouptut = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
|
||||
EXPECTED_GENERATION = """
|
||||
Describe the images.
|
||||
Sure, let's break down each image description:
|
||||
|
||||
1. **Image 1:**
|
||||
- **Description:** A black dog with a glossy coat is sitting on a wooden floor. The dog has a focused expression and is looking directly at the camera.
|
||||
- **Details:** The wooden floor has a rustic appearance with visible wood grain patterns. The dog's eyes are a striking color, possibly brown or amber, which contrasts with its black fur.
|
||||
|
||||
2. **Image 2:**
|
||||
- **Description:** A scenic view of a mountainous landscape with a winding road cutting through it. The road is surrounded by lush green vegetation and leads to a distant valley.
|
||||
- **Details:** The mountains are rugged with steep slopes, and the sky is clear, indicating good weather. The winding road adds a sense of depth and perspective to the image.
|
||||
|
||||
3. **Image 3:**
|
||||
- **Description:** A beach scene with waves crashing against the shore. There are several people in the water and on the beach, enjoying the waves and the sunset.
|
||||
- **Details:** The waves are powerful, creating a dynamic and lively atmosphere. The sky is painted with hues of orange and pink from the setting sun, adding a warm glow to the scene.
|
||||
|
||||
4. **Image 4:**
|
||||
- **Description:** A garden path leading to a large tree with a bench underneath it. The path is bordered by well-maintained grass and flowers.
|
||||
- **Details:** The path is made of small stones or gravel, and the tree provides a shaded area with the bench invitingly placed beneath it. The surrounding area is lush and green, suggesting a well-kept garden.
|
||||
|
||||
Each image captures a different scene, from a close-up of a dog to expansive natural landscapes, showcasing various elements of nature and human interaction with it.
|
||||
"""
|
||||
|
||||
```
|
||||
## PixtralVisionConfig
|
||||
|
||||
[[autodoc]] PixtralVisionConfig
|
||||
|
||||
## PixtralModel
|
||||
|
||||
[[autodoc]] PixtralModel
|
||||
- forward
|
||||
|
||||
## PixtralImageProcessor
|
||||
|
||||
[[autodoc]] PixtralImageProcessor
|
||||
- preprocess
|
||||
|
||||
## PixtralProcessor
|
||||
|
||||
[[autodoc]] PixtralProcessor
|
||||
@@ -649,6 +649,7 @@ _import_structure = {
|
||||
"Pix2StructTextConfig",
|
||||
"Pix2StructVisionConfig",
|
||||
],
|
||||
"models.pixtral": ["PixtralProcessor", "PixtralVisionConfig"],
|
||||
"models.plbart": ["PLBartConfig"],
|
||||
"models.poolformer": ["PoolFormerConfig"],
|
||||
"models.pop2piano": ["Pop2PianoConfig"],
|
||||
@@ -1199,6 +1200,7 @@ else:
|
||||
_import_structure["models.owlvit"].extend(["OwlViTFeatureExtractor", "OwlViTImageProcessor"])
|
||||
_import_structure["models.perceiver"].extend(["PerceiverFeatureExtractor", "PerceiverImageProcessor"])
|
||||
_import_structure["models.pix2struct"].extend(["Pix2StructImageProcessor"])
|
||||
_import_structure["models.pixtral"].append("PixtralImageProcessor")
|
||||
_import_structure["models.poolformer"].extend(["PoolFormerFeatureExtractor", "PoolFormerImageProcessor"])
|
||||
_import_structure["models.pvt"].extend(["PvtImageProcessor"])
|
||||
_import_structure["models.qwen2_vl"].extend(["Qwen2VLImageProcessor"])
|
||||
@@ -1359,7 +1361,6 @@ else:
|
||||
"AlignVisionModel",
|
||||
]
|
||||
)
|
||||
|
||||
_import_structure["models.altclip"].extend(
|
||||
[
|
||||
"AltCLIPModel",
|
||||
@@ -2977,6 +2978,7 @@ else:
|
||||
"Pix2StructVisionModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.pixtral"].extend(["PixtralModel", "PixtralPreTrainedModel"])
|
||||
_import_structure["models.plbart"].extend(
|
||||
[
|
||||
"PLBartForCausalLM",
|
||||
@@ -5434,6 +5436,10 @@ if TYPE_CHECKING:
|
||||
Pix2StructTextConfig,
|
||||
Pix2StructVisionConfig,
|
||||
)
|
||||
from .models.pixtral import (
|
||||
PixtralProcessor,
|
||||
PixtralVisionConfig,
|
||||
)
|
||||
from .models.plbart import PLBartConfig
|
||||
from .models.poolformer import (
|
||||
PoolFormerConfig,
|
||||
@@ -6009,6 +6015,7 @@ if TYPE_CHECKING:
|
||||
from .models.owlvit import OwlViTFeatureExtractor, OwlViTImageProcessor
|
||||
from .models.perceiver import PerceiverFeatureExtractor, PerceiverImageProcessor
|
||||
from .models.pix2struct import Pix2StructImageProcessor
|
||||
from .models.pixtral import PixtralImageProcessor
|
||||
from .models.poolformer import (
|
||||
PoolFormerFeatureExtractor,
|
||||
PoolFormerImageProcessor,
|
||||
@@ -7448,6 +7455,10 @@ if TYPE_CHECKING:
|
||||
Pix2StructTextModel,
|
||||
Pix2StructVisionModel,
|
||||
)
|
||||
from .models.pixtral import (
|
||||
PixtralModel,
|
||||
PixtralPreTrainedModel,
|
||||
)
|
||||
from .models.plbart import (
|
||||
PLBartForCausalLM,
|
||||
PLBartForConditionalGeneration,
|
||||
|
||||
@@ -187,6 +187,7 @@ from . import (
|
||||
phi3,
|
||||
phobert,
|
||||
pix2struct,
|
||||
pixtral,
|
||||
plbart,
|
||||
poolformer,
|
||||
pop2piano,
|
||||
|
||||
@@ -205,6 +205,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
|
||||
("phi", "PhiConfig"),
|
||||
("phi3", "Phi3Config"),
|
||||
("pix2struct", "Pix2StructConfig"),
|
||||
("pixtral", "PixtralVisionConfig"),
|
||||
("plbart", "PLBartConfig"),
|
||||
("poolformer", "PoolFormerConfig"),
|
||||
("pop2piano", "Pop2PianoConfig"),
|
||||
@@ -509,6 +510,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
||||
("phi3", "Phi3"),
|
||||
("phobert", "PhoBERT"),
|
||||
("pix2struct", "Pix2Struct"),
|
||||
("pixtral", "Pixtral"),
|
||||
("plbart", "PLBart"),
|
||||
("poolformer", "PoolFormer"),
|
||||
("pop2piano", "Pop2Piano"),
|
||||
|
||||
@@ -114,6 +114,7 @@ else:
|
||||
("owlvit", ("OwlViTImageProcessor",)),
|
||||
("perceiver", ("PerceiverImageProcessor",)),
|
||||
("pix2struct", ("Pix2StructImageProcessor",)),
|
||||
("pixtral", ("PixtralImageProcessor",)),
|
||||
("poolformer", ("PoolFormerImageProcessor",)),
|
||||
("pvt", ("PvtImageProcessor",)),
|
||||
("pvt_v2", ("PvtImageProcessor",)),
|
||||
|
||||
@@ -193,6 +193,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("persimmon", "PersimmonModel"),
|
||||
("phi", "PhiModel"),
|
||||
("phi3", "Phi3Model"),
|
||||
("pixtral", "PixtralModel"),
|
||||
("plbart", "PLBartModel"),
|
||||
("poolformer", "PoolFormerModel"),
|
||||
("prophetnet", "ProphetNetModel"),
|
||||
|
||||
@@ -82,6 +82,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
|
||||
("owlvit", "OwlViTProcessor"),
|
||||
("paligemma", "PaliGemmaProcessor"),
|
||||
("pix2struct", "Pix2StructProcessor"),
|
||||
("pixtral", "PixtralProcessor"),
|
||||
("pop2piano", "Pop2PianoProcessor"),
|
||||
("qwen2_audio", "Qwen2AudioProcessor"),
|
||||
("qwen2_vl", "Qwen2VLProcessor"),
|
||||
|
||||
@@ -385,6 +385,7 @@ else:
|
||||
("phi3", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("phobert", ("PhobertTokenizer", None)),
|
||||
("pix2struct", ("T5Tokenizer", "T5TokenizerFast" if is_tokenizers_available() else None)),
|
||||
("pixtral", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("plbart", ("PLBartTokenizer" if is_sentencepiece_available() else None, None)),
|
||||
("prophetnet", ("ProphetNetTokenizer", None)),
|
||||
("qdqbert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
|
||||
|
||||
@@ -73,7 +73,7 @@ class LlavaConfig(PretrainedConfig):
|
||||
```"""
|
||||
|
||||
model_type = "llava"
|
||||
is_composition = False
|
||||
is_composition = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
70
src/transformers/models/pixtral/__init__.py
Normal file
70
src/transformers/models/pixtral/__init__.py
Normal file
@@ -0,0 +1,70 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
|
||||
|
||||
|
||||
_import_structure = {
|
||||
"configuration_pixtral": ["PixtralVisionConfig"],
|
||||
"processing_pixtral": ["PixtralProcessor"],
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_pixtral"] = [
|
||||
"PixtralModel",
|
||||
"PixtralPreTrainedModel",
|
||||
]
|
||||
|
||||
try:
|
||||
if not is_vision_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["image_processing_pixtral"] = ["PixtralImageProcessor"]
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_pixtral import PixtralProcessor, PixtralVisionConfig
|
||||
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .modeling_pixtral import (
|
||||
PixtralModel,
|
||||
PixtralPreTrainedModel,
|
||||
)
|
||||
|
||||
try:
|
||||
if not is_vision_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .image_processing_pixtral import PixtralImageProcessor
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
|
||||
103
src/transformers/models/pixtral/configuration_pixtral.py
Normal file
103
src/transformers/models/pixtral/configuration_pixtral.py
Normal file
@@ -0,0 +1,103 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Pixtral model configuration"""
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class PixtralVisionConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`PixtralModel`]. It is used to instantiate an
|
||||
Pixtral model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||
with the defaults will yield a similar configuration to that of the Pixtral-9B.
|
||||
|
||||
e.g. [pixtral-hf/pixtral-9b](https://huggingface.co/pixtral-hf/pixtral-9b)
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
hidden_size (`int`, *optional*, defaults to 1024):
|
||||
Dimension of the hidden representations.
|
||||
intermediate_size (`int`, *optional*, defaults to 4096):
|
||||
Dimension of the MLP representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 24):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 16):
|
||||
Number of attention heads in the Transformer encoder.
|
||||
num_channels (`int`, *optional*, defaults to 3):
|
||||
Number of input channels in the input images.
|
||||
image_size (`int`, *optional*, defaults to 1024):
|
||||
Max dimension of the input images.
|
||||
patch_size (`int`, *optional*, defaults to 16):
|
||||
Size of the image patches.
|
||||
hidden_act (`str`, *optional*, defaults to `"gelu"`):
|
||||
Activation function used in the hidden layers.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
Dropout probability for the attention layers.
|
||||
rope_theta (`float`, *optional*, defaults to 10000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether to tie the word embeddings with the input embeddings.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import PixtralModel, PixtralVisionConfig, CLIPVisionConfig, LlamaConfig
|
||||
|
||||
>>> # Initializing a Pixtral 12B style configuration
|
||||
>>> config = PixtralVisionConfig()
|
||||
|
||||
>>> # Initializing a model from the pixtral 12B style configuration
|
||||
>>> model = PixtralModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "pixtral"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size=1024,
|
||||
intermediate_size=4096,
|
||||
num_hidden_layers=24,
|
||||
num_attention_heads=16,
|
||||
num_channels=3,
|
||||
image_size=1024,
|
||||
patch_size=16,
|
||||
hidden_act="gelu",
|
||||
attention_dropout=0.0,
|
||||
rope_theta=10000.0,
|
||||
tie_word_embeddings=False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_channels = num_channels
|
||||
self.patch_size = patch_size
|
||||
self.image_size = image_size
|
||||
self.attention_dropout = attention_dropout
|
||||
self.hidden_act = hidden_act
|
||||
self.rope_theta = rope_theta
|
||||
self.tie_word_embeddings = tie_word_embeddings
|
||||
self.head_dim = hidden_size // num_attention_heads
|
||||
285
src/transformers/models/pixtral/convert_pixtral_weights_to_hf.py
Normal file
285
src/transformers/models/pixtral/convert_pixtral_weights_to_hf.py
Normal file
@@ -0,0 +1,285 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
|
||||
import regex as re
|
||||
import torch
|
||||
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
|
||||
from safetensors.torch import load_file as safe_load_file
|
||||
from tokenizers import Regex, Tokenizer, decoders, pre_tokenizers, processors
|
||||
from tokenizers.models import BPE
|
||||
|
||||
from transformers import (
|
||||
LlavaConfig,
|
||||
LlavaForConditionalGeneration,
|
||||
MistralConfig,
|
||||
PixtralImageProcessor,
|
||||
PixtralProcessor,
|
||||
PixtralVisionConfig,
|
||||
PreTrainedTokenizerFast,
|
||||
)
|
||||
from transformers.convert_slow_tokenizer import bytes_to_unicode
|
||||
|
||||
|
||||
"""
|
||||
# Here is how to get the original tokens!
|
||||
model_name = "mistralai/Pixtral-12B-2409"
|
||||
tok = MistralTokenizer.from_model(model_name)
|
||||
|
||||
from mistral_common.protocol.instruct.request import ChatCompletionRequest, UserMessage, ImageChunk, TextChunk
|
||||
|
||||
EXPECTED_TOKENS = tok.encode_chat_completion(
|
||||
ChatCompletionRequest(
|
||||
messages=[
|
||||
UserMessage(
|
||||
content=[
|
||||
TextChunk(text="Describe the images"),
|
||||
] + [ImageChunk(image=img) for img in IMG_URLS]
|
||||
)
|
||||
],
|
||||
model="pixtral",
|
||||
)
|
||||
)
|
||||
assert tokenizer.decode(inputs["input_ids"][0]) == EXPECTED_TOKENS
|
||||
"""
|
||||
|
||||
OLD_KEY_TO_NEW_KEY_MAPPING = {
|
||||
# Layer Normalization Weights
|
||||
r"vision_encoder.transformer.layers.(\d+).input_layernorm.weight": r"vision_tower.transformer.layers.\1.attention_norm.weight",
|
||||
r"vision_encoder.transformer.layers.(\d+).ffn_norm.weight": r"vision_tower.transformer.layers.\1.ffn_norm.weight",
|
||||
# Self Attention Projections
|
||||
r"vision_encoder.transformer.layers.(\d+).attention.wq.weight": r"vision_tower.transformer.layers.\1.attention.q_proj.weight",
|
||||
r"vision_encoder.transformer.layers.(\d+).attention.wk.weight": r"vision_tower.transformer.layers.\1.attention.k_proj.weight",
|
||||
r"vision_encoder.transformer.layers.(\d+).attention.wv.weight": r"vision_tower.transformer.layers.\1.attention.v_proj.weight",
|
||||
r"vision_encoder.transformer.layers.(\d+).attention.wo.weight": r"vision_tower.transformer.layers.\1.attention.o_proj.weight",
|
||||
# MLP Projections
|
||||
r"vision_encoder.transformer.layers.(\d+).feed_forward.w1.weight": r"vision_tower.transformer.layers.\1.feed_forward.gate_proj.weight",
|
||||
r"vision_encoder.transformer.layers.(\d+).feed_forward.w2.weight": r"vision_tower.transformer.layers.\1.feed_forward.down_proj.weight",
|
||||
r"vision_encoder.transformer.layers.(\d+).feed_forward.w3.weight": r"vision_tower.transformer.layers.\1.feed_forward.up_proj.weight",
|
||||
# Additional mappings
|
||||
r"vision_encoder": r"vision_tower",
|
||||
r"vision_language_adapter.w_in": r"multi_modal_projector.linear_1",
|
||||
r"vision_language_adapter.w_out": r"multi_modal_projector.linear_2",
|
||||
r"layers.(\d+).attention.wq.weight": r"language_model.model.layers.\1.self_attn.q_proj.weight",
|
||||
r"layers.(\d+).attention.wk.weight": r"language_model.model.layers.\1.self_attn.k_proj.weight",
|
||||
r"layers.(\d+).attention.wv.weight": r"language_model.model.layers.\1.self_attn.v_proj.weight",
|
||||
r"layers.(\d+).attention.wo.weight": r"language_model.model.layers.\1.self_attn.o_proj.weight",
|
||||
r"layers.(\d+).feed_forward.w1.weight": r"language_model.model.layers.\1.mlp.gate_proj.weight",
|
||||
r"layers.(\d+).feed_forward.w2.weight": r"language_model.model.layers.\1.mlp.down_proj.weight",
|
||||
r"layers.(\d+).feed_forward.w3.weight": r"language_model.model.layers.\1.mlp.up_proj.weight",
|
||||
r"layers.(\d+).ffn_norm.weight": r"language_model.model.layers.\1.post_attention_layernorm.weight",
|
||||
r"layers.(\d+).attention_norm.weight": r"language_model.model.layers.\1.input_layernorm.weight",
|
||||
r"tok_embeddings.weight": r"language_model.model.embed_tokens.weight",
|
||||
r"output.weight": r"language_model.lm_head.weight",
|
||||
r"norm.weight": r"language_model.model.norm.weight",
|
||||
}
|
||||
|
||||
|
||||
class MistralConverter:
|
||||
"""
|
||||
A general tiktoken converter.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab=None,
|
||||
pattern=r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""",
|
||||
add_prefix_space=False,
|
||||
additional_special_tokens=None,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args)
|
||||
self.vocab = vocab
|
||||
self.pattern = pattern
|
||||
self.add_prefix_space = add_prefix_space
|
||||
self.additional_special_tokens = additional_special_tokens
|
||||
|
||||
def extract_vocab_merges_from_model(self, vocab: str):
|
||||
bpe_ranks = vocab
|
||||
byte_encoder = bytes_to_unicode()
|
||||
|
||||
def token_bytes_to_string(b):
|
||||
return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")])
|
||||
|
||||
merges = []
|
||||
vocab = {}
|
||||
for idx, (token, rank) in enumerate(bpe_ranks.items()):
|
||||
if token not in self.additional_special_tokens:
|
||||
vocab[token_bytes_to_string(token)] = idx
|
||||
if len(token) == 1:
|
||||
continue
|
||||
local = []
|
||||
for index in range(1, len(token)):
|
||||
piece_l, piece_r = token[:index], token[index:]
|
||||
if piece_l in bpe_ranks and piece_r in bpe_ranks and (piece_l + piece_r) in bpe_ranks:
|
||||
local.append((piece_l, piece_r, rank))
|
||||
local = sorted(local, key=lambda x: (bpe_ranks[x[0]], bpe_ranks[x[1]]), reverse=False)
|
||||
merges.extend(local)
|
||||
else:
|
||||
vocab[token] = idx
|
||||
merges = sorted(merges, key=lambda val: val[2], reverse=False)
|
||||
merges = [(token_bytes_to_string(val[0]), token_bytes_to_string(val[1])) for val in merges]
|
||||
return vocab, merges
|
||||
|
||||
def tokenizer(self):
|
||||
vocab_scores, merges = self.extract_vocab_merges_from_model(self.vocab)
|
||||
tokenizer = Tokenizer(BPE(vocab_scores, merges, fuse_unk=False))
|
||||
if hasattr(tokenizer.model, "ignore_merges"):
|
||||
tokenizer.model.ignore_merges = True
|
||||
return tokenizer
|
||||
|
||||
def converted(self) -> Tokenizer:
|
||||
tokenizer = self.tokenizer()
|
||||
tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
|
||||
[
|
||||
pre_tokenizers.Split(Regex(self.pattern), behavior="isolated", invert=False),
|
||||
pre_tokenizers.ByteLevel(add_prefix_space=self.add_prefix_space, use_regex=False),
|
||||
]
|
||||
)
|
||||
tokenizer.decoder = decoders.ByteLevel()
|
||||
tokenizer.add_special_tokens(self.additional_special_tokens)
|
||||
|
||||
tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
def convert_mistral_tokenizer():
|
||||
model_name = "mistralai/Pixtral-12B-2409"
|
||||
|
||||
tokenizer = MistralTokenizer.from_model(model_name)
|
||||
|
||||
vocab = tokenizer.instruct_tokenizer.tokenizer._tekken_token2id_nospecial
|
||||
all_special = [
|
||||
token.value if hasattr(token, "value") else token
|
||||
for token in tokenizer.instruct_tokenizer.tokenizer._all_special_tokens
|
||||
]
|
||||
specials_tokens = {token: all_special.index(token) for token in all_special}
|
||||
specials_tokens.update(vocab)
|
||||
vocab = specials_tokens
|
||||
|
||||
tokenizer = PreTrainedTokenizerFast(
|
||||
tokenizer_object=MistralConverter(vocab=vocab, additional_special_tokens=all_special).converted(),
|
||||
bos_token="<s>",
|
||||
unk_token="<unk>",
|
||||
eos_token="</s>",
|
||||
)
|
||||
tokenizer.model_input_names = ["input_ids", "attention_mask"]
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
def permute_for_rope(value, n_heads, config):
|
||||
dim1 = value.shape[0]
|
||||
dim2 = config.hidden_size
|
||||
return value.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
|
||||
|
||||
|
||||
def convert_dictionnary(original_state_dict, vision_config, text_config):
|
||||
new_dict = {}
|
||||
|
||||
all_keys = "\n" + "\n".join(original_state_dict.keys())
|
||||
old_keys = all_keys
|
||||
for old, new in OLD_KEY_TO_NEW_KEY_MAPPING.items():
|
||||
all_keys = re.sub(r"\n" + old, r"\n" + new, all_keys)
|
||||
|
||||
OLD_TO_NEW = dict(zip(old_keys.split("\n"), all_keys.split("\n")))
|
||||
|
||||
for key, value in original_state_dict.items():
|
||||
new_key = OLD_TO_NEW[key]
|
||||
if "vision_encoder" in key:
|
||||
_config = vision_config
|
||||
num_attention_heads = _config.num_attention_heads
|
||||
else:
|
||||
_config = text_config
|
||||
if "q_proj" in new_key:
|
||||
num_attention_heads = _config.num_attention_heads
|
||||
if "k_proj" in new_key:
|
||||
num_attention_heads = _config.num_key_value_heads
|
||||
# convert the text model (basically mistral model)
|
||||
|
||||
if "q_proj" in new_key or "k_proj" in new_key:
|
||||
value = permute_for_rope(value, num_attention_heads, _config)
|
||||
|
||||
new_dict[new_key] = value
|
||||
return new_dict
|
||||
|
||||
|
||||
def convert_mistral_model(input_dir, output_dir):
|
||||
text_config = MistralConfig(
|
||||
attention_dropout=0.0,
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
head_dim=128,
|
||||
hidden_act="silu",
|
||||
hidden_size=5120,
|
||||
initializer_range=0.02,
|
||||
intermediate_size=14336,
|
||||
max_position_embeddings=1024000,
|
||||
model_type="mistral",
|
||||
num_attention_heads=32,
|
||||
num_hidden_layers=40,
|
||||
num_key_value_heads=8,
|
||||
rms_norm_eps=1e-05,
|
||||
rope_theta=1000000000.0,
|
||||
sliding_window=None,
|
||||
tie_word_embeddings=False,
|
||||
vocab_size=131072,
|
||||
)
|
||||
|
||||
vision_config = PixtralVisionConfig()
|
||||
config = LlavaConfig(
|
||||
vision_config,
|
||||
text_config,
|
||||
vision_feature_layer=-1,
|
||||
image_token_index=10,
|
||||
vision_feature_select_strategy="full",
|
||||
image_seq_length=1,
|
||||
)
|
||||
config.architectures = ["LlavaForConditionalGeneration"]
|
||||
config.save_pretrained(output_dir)
|
||||
|
||||
original_state_dict = safe_load_file(f"{input_dir}/consolidated.safetensors")
|
||||
new_dict = convert_dictionnary(original_state_dict, vision_config, text_config)
|
||||
|
||||
with torch.device("meta"):
|
||||
model = LlavaForConditionalGeneration(config)
|
||||
model.load_state_dict(new_dict, strict=True, assign=True)
|
||||
|
||||
model.save_pretrained(output_dir)
|
||||
|
||||
tokenizer = convert_mistral_tokenizer()
|
||||
image_processor = PixtralImageProcessor()
|
||||
processor = PixtralProcessor(tokenizer=tokenizer, image_processor=image_processor, image_token="[IMG]")
|
||||
processor.save_pretrained(output_dir)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--input_dir",
|
||||
help="Location of LLaMA weights, which contains tokenizer.model and model folders",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
help="Location to write HF model and tokenizer",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_mistral_model(args.input_dir, args.output_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
519
src/transformers/models/pixtral/image_processing_pixtral.py
Normal file
519
src/transformers/models/pixtral/image_processing_pixtral.py
Normal file
@@ -0,0 +1,519 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Image processor class for Pixtral."""
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
||||
from ...image_transforms import (
|
||||
resize,
|
||||
to_channel_dimension_format,
|
||||
)
|
||||
from ...image_utils import (
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
get_image_size,
|
||||
infer_channel_dimension_format,
|
||||
is_scaled_image,
|
||||
is_valid_image,
|
||||
to_numpy_array,
|
||||
valid_images,
|
||||
validate_kwargs,
|
||||
validate_preprocess_arguments,
|
||||
)
|
||||
from ...utils import TensorType, is_torch_device, is_torch_dtype, is_torch_tensor, is_vision_available, logging
|
||||
from ...utils.import_utils import requires_backends
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
import PIL
|
||||
|
||||
|
||||
class BatchMixFeature(BatchFeature):
|
||||
def to(self, *args, **kwargs) -> "BatchMixFeature":
|
||||
"""
|
||||
Send all values to device by calling `v.to(*args, **kwargs)` (PyTorch only). This should support casting in
|
||||
different `dtypes` and sending the `BatchFeature` to a different `device`.
|
||||
|
||||
Args:
|
||||
args (`Tuple`):
|
||||
Will be passed to the `to(...)` function of the tensors.
|
||||
kwargs (`Dict`, *optional*):
|
||||
Will be passed to the `to(...)` function of the tensors.
|
||||
|
||||
Returns:
|
||||
[`BatchFeature`]: The same instance after modification.
|
||||
"""
|
||||
requires_backends(self, ["torch"])
|
||||
import torch # noqa
|
||||
|
||||
new_data = {}
|
||||
device = kwargs.get("device")
|
||||
# Check if the args are a device or a dtype
|
||||
if device is None and len(args) > 0:
|
||||
# device should be always the first argument
|
||||
arg = args[0]
|
||||
if is_torch_dtype(arg):
|
||||
# The first argument is a dtype
|
||||
pass
|
||||
elif isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int):
|
||||
device = arg
|
||||
else:
|
||||
# it's something else
|
||||
raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")
|
||||
# We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
|
||||
for k, v in self.items():
|
||||
# check if v is a floating point
|
||||
if isinstance(v, list):
|
||||
new_data[k] = [
|
||||
element.to(*args, **kwargs) for sample in v for element in sample if is_torch_tensor(element)
|
||||
]
|
||||
elif torch.is_floating_point(v):
|
||||
# cast and send to device
|
||||
new_data[k] = v.to(*args, **kwargs)
|
||||
elif device is not None:
|
||||
new_data[k] = v.to(device=device)
|
||||
else:
|
||||
new_data[k] = v
|
||||
self.data = new_data
|
||||
return self
|
||||
|
||||
|
||||
# Copied from transformers.models.idefics2.image_processing_idefics2.make_list_of_images
|
||||
def make_list_of_images(images: ImageInput) -> List[List[np.ndarray]]:
|
||||
"""
|
||||
Convert a single image or a list of images to a list of numpy arrays.
|
||||
|
||||
Args:
|
||||
images (`ImageInput`):
|
||||
A single image or a list of images.
|
||||
|
||||
Returns:
|
||||
A list of numpy arrays.
|
||||
"""
|
||||
# If it's a single image, convert it to a list of lists
|
||||
if is_valid_image(images):
|
||||
images = [[images]]
|
||||
# If it's a list of images, it's a single batch, so convert it to a list of lists
|
||||
elif isinstance(images, (list, tuple)) and len(images) > 0 and is_valid_image(images[0]):
|
||||
images = [images]
|
||||
# If it's a list of batches, it's already in the right format
|
||||
elif (
|
||||
isinstance(images, (list, tuple))
|
||||
and len(images) > 0
|
||||
and isinstance(images[0], (list, tuple))
|
||||
and is_valid_image(images[0][0])
|
||||
):
|
||||
pass
|
||||
else:
|
||||
raise ValueError(
|
||||
"Invalid input type. Must be a single image, a list of images, or a list of batches of images."
|
||||
)
|
||||
return images
|
||||
|
||||
|
||||
# Adapted from function in image_transforms.py to ensure any transparent pixels are converted to white.
|
||||
def convert_to_rgb(image: ImageInput) -> ImageInput:
|
||||
"""
|
||||
Converts an image to RGB format. Only converts if the image is of type PIL.Image.Image, otherwise returns the image
|
||||
as is.
|
||||
Args:
|
||||
image (Image):
|
||||
The image to convert.
|
||||
"""
|
||||
requires_backends(convert_to_rgb, ["vision"])
|
||||
|
||||
if not isinstance(image, PIL.Image.Image):
|
||||
return image
|
||||
|
||||
if image.mode == "RGB":
|
||||
return image
|
||||
|
||||
# First we convert to RGBA to set background to white.
|
||||
image = image.convert("RGBA")
|
||||
|
||||
# Create a new image with a white background.
|
||||
new_image = PIL.Image.new("RGBA", image.size, "WHITE")
|
||||
new_image.paste(image, (0, 0), image)
|
||||
new_image = new_image.convert("RGB")
|
||||
return new_image
|
||||
|
||||
|
||||
def _num_image_tokens(image_size: Tuple[int, int], patch_size: Tuple[int, int]) -> int:
|
||||
"""
|
||||
Calculate the number of image tokens given the image size and patch size.
|
||||
|
||||
Args:
|
||||
image_size (`Tuple[int, int]`):
|
||||
The size of the image as `(height, width)`.
|
||||
patch_size (`Tuple[int, int]`):
|
||||
The patch size as `(height, width)`.
|
||||
|
||||
Returns:
|
||||
`int`: The number of image tokens.
|
||||
"""
|
||||
height, width = image_size
|
||||
patch_height, patch_width = patch_size if isinstance(patch_size, (tuple, list)) else (patch_size, patch_size)
|
||||
num_width_tokens = (width - 1) // patch_width + 1
|
||||
num_height_tokens = (height - 1) // patch_height + 1
|
||||
return num_height_tokens, num_width_tokens
|
||||
|
||||
|
||||
def get_resize_output_image_size(
|
||||
input_image: np.ndarray,
|
||||
size: Union[int, Tuple[int, int], List[int], Tuple[int]],
|
||||
patch_size: Union[int, Tuple[int, int], List[int], Tuple[int]],
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> tuple:
|
||||
"""
|
||||
Find the target (height, width) dimension of the output image after resizing given the input image and the desired
|
||||
size.
|
||||
|
||||
Args:
|
||||
input_image (`np.ndarray`):
|
||||
The image to resize.
|
||||
size (`int` or `Tuple[int, int]`):
|
||||
Max image size an input image can be. Must be a dictionary with the key "longest_edge".
|
||||
patch_size (`int` or `Tuple[int, int]`):
|
||||
The patch_size as `(height, width)` to use for resizing the image. If patch_size is an integer, `(patch_size, patch_size)`
|
||||
will be used
|
||||
input_data_format (`ChannelDimension`, *optional*):
|
||||
The channel dimension format of the input image. If unset, will use the inferred format from the input.
|
||||
|
||||
Returns:
|
||||
`tuple`: The target (height, width) dimension of the output image after resizing.
|
||||
"""
|
||||
max_height, max_width = size if isinstance(size, (tuple, list)) else (size, size)
|
||||
patch_height, patch_width = patch_size if isinstance(patch_size, (tuple, list)) else (patch_size, patch_size)
|
||||
height, width = get_image_size(input_image, input_data_format)
|
||||
|
||||
ratio = max(height / max_height, width / max_width)
|
||||
|
||||
if ratio > 1:
|
||||
# Orgiginal implementation uses `round` which utilises bankers rounding, which can lead to surprising results
|
||||
height = int(np.ceil(height / ratio))
|
||||
width = int(np.ceil(width / ratio))
|
||||
|
||||
num_height_tokens, num_width_tokens = _num_image_tokens((height, width), (patch_height, patch_width))
|
||||
return num_height_tokens * patch_height, num_width_tokens * patch_width
|
||||
|
||||
|
||||
# Hack to get tensor conversion used in BatchFeature without batching the images
|
||||
def _get_is_as_tensor_fns(tensor_type: Union[str, TensorType]) -> Tuple[Callable, Callable]:
|
||||
return BatchFeature()._get_is_as_tensor_fns(tensor_type)
|
||||
|
||||
|
||||
def convert_to_tensor(array, tensor_type: Union[str, TensorType]) -> Any:
|
||||
is_tensor, as_tensor = _get_is_as_tensor_fns(tensor_type)
|
||||
if is_tensor(array):
|
||||
return array
|
||||
return as_tensor(array)
|
||||
|
||||
|
||||
class PixtralImageProcessor(BaseImageProcessor):
|
||||
r"""
|
||||
Constructs a Pixtral image processor.
|
||||
|
||||
Args:
|
||||
do_resize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
|
||||
`do_resize` in the `preprocess` method.
|
||||
size (`Dict[str, int]` *optional*, defaults to `{"longest_edge": 1024}`):
|
||||
Size of the maximum dimension of either the height or width dimension of the image. Used to control how
|
||||
images are resized. If either the height or width are greater than `size["longest_edge"]` then both the height and width are rescaled by `height / ratio`, `width /ratio` where `ratio = max(height / longest_edge, width / longest_edge)`
|
||||
patch_size (`Dict[str, int]` *optional*, defaults to `{"height": 16, "width": 16}`):
|
||||
Size of the patches in the model, used to calculate the output image size. Can be overridden by `patch_size` in the `preprocess` method.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
|
||||
Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
|
||||
do_rescale (`bool`, *optional*, defaults to `True`):
|
||||
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
|
||||
the `preprocess` method.
|
||||
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
||||
Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
|
||||
method.
|
||||
do_normalize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method.
|
||||
image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
|
||||
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
|
||||
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
|
||||
image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
|
||||
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
|
||||
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
|
||||
Can be overridden by the `image_std` parameter in the `preprocess` method.
|
||||
do_convert_rgb (`bool`, *optional*, defaults to `True`):
|
||||
Whether to convert the image to RGB.
|
||||
"""
|
||||
|
||||
model_input_names = ["pixel_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
do_resize: bool = True,
|
||||
size: Dict[str, int] = None,
|
||||
patch_size: Dict[str, int] = None,
|
||||
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
||||
do_rescale: bool = True,
|
||||
rescale_factor: Union[int, float] = 1 / 255,
|
||||
do_normalize: bool = True,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
do_convert_rgb: bool = True,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
size = size if size is not None else {"longest_edge": 1024}
|
||||
patch_size = patch_size if patch_size is not None else {"height": 16, "width": 16}
|
||||
patch_size = get_size_dict(patch_size, default_to_square=True)
|
||||
|
||||
self.do_resize = do_resize
|
||||
self.size = size
|
||||
self.patch_size = patch_size
|
||||
self.resample = resample
|
||||
self.do_rescale = do_rescale
|
||||
self.rescale_factor = rescale_factor
|
||||
self.do_normalize = do_normalize
|
||||
self.image_mean = image_mean if image_mean is not None else [0.48145466, 0.4578275, 0.40821073]
|
||||
self.image_std = image_std if image_std is not None else [0.26862954, 0.26130258, 0.27577711]
|
||||
self.do_convert_rgb = do_convert_rgb
|
||||
self._valid_processor_keys = [
|
||||
"images",
|
||||
"do_resize",
|
||||
"size",
|
||||
"patch_size",
|
||||
"resample",
|
||||
"do_rescale",
|
||||
"rescale_factor",
|
||||
"do_normalize",
|
||||
"image_mean",
|
||||
"image_std",
|
||||
"do_convert_rgb",
|
||||
"return_tensors",
|
||||
"data_format",
|
||||
"input_data_format",
|
||||
]
|
||||
|
||||
def resize(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
size: Dict[str, int],
|
||||
patch_size: Dict[str, int],
|
||||
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge
|
||||
resized to keep the input aspect ratio.
|
||||
|
||||
Args:
|
||||
image (`np.ndarray`):
|
||||
Image to resize.
|
||||
size (`Dict[str, int]`):
|
||||
Dict containing the longest possible edge of the image.
|
||||
patch_size (`Dict[str, int]`):
|
||||
Patch size used to calculate the size of the output image.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
|
||||
Resampling filter to use when resiizing the image.
|
||||
data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||
"""
|
||||
if "longest_edge" in size:
|
||||
size = (size["longest_edge"], size["longest_edge"])
|
||||
elif "height" in size and "width" in size:
|
||||
size = (size["height"], size["width"])
|
||||
else:
|
||||
raise ValueError("size must contain either 'longest_edge' or 'height' and 'width'.")
|
||||
|
||||
if "height" in patch_size and "width" in patch_size:
|
||||
patch_size = (patch_size["height"], patch_size["width"])
|
||||
else:
|
||||
raise ValueError("patch_size must contain either 'shortest_edge' or 'height' and 'width'.")
|
||||
|
||||
output_size = get_resize_output_image_size(
|
||||
image,
|
||||
size=size,
|
||||
patch_size=patch_size,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
return resize(
|
||||
image,
|
||||
size=output_size,
|
||||
resample=resample,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
images: ImageInput,
|
||||
do_resize: bool = None,
|
||||
size: Dict[str, int] = None,
|
||||
patch_size: Dict[str, int] = None,
|
||||
resample: PILImageResampling = None,
|
||||
do_rescale: bool = None,
|
||||
rescale_factor: float = None,
|
||||
do_normalize: bool = None,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
do_convert_rgb: bool = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> PIL.Image.Image:
|
||||
"""
|
||||
Preprocess an image or batch of images.
|
||||
|
||||
Args:
|
||||
images (`ImageInput`):
|
||||
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
|
||||
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
|
||||
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
||||
Whether to resize the image.
|
||||
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
||||
Describes the maximum input dimensions to the model.
|
||||
patch_size (`Dict[str, int]`, *optional*, defaults to `self.patch_size`):
|
||||
Patch size in the model. Used to calculate the image after resizing.
|
||||
resample (`int`, *optional*, defaults to `self.resample`):
|
||||
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
|
||||
has an effect if `do_resize` is set to `True`.
|
||||
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
||||
Whether to rescale the image.
|
||||
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
||||
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
||||
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
||||
Whether to normalize the image.
|
||||
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
||||
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
|
||||
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
||||
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
|
||||
`True`.
|
||||
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
||||
Whether to convert the image to RGB.
|
||||
return_tensors (`str` or `TensorType`, *optional*):
|
||||
The type of tensors to return. Can be one of:
|
||||
- Unset: Return a list of `np.ndarray`.
|
||||
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
||||
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
||||
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
||||
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
||||
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
||||
The channel dimension format for the output image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- Unset: Use the channel dimension format of the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
"""
|
||||
patch_size = patch_size if patch_size is not None else self.patch_size
|
||||
patch_size = get_size_dict(patch_size, default_to_square=True)
|
||||
|
||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||
size = size if size is not None else self.size
|
||||
resample = resample if resample is not None else self.resample
|
||||
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
||||
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
||||
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
||||
image_mean = image_mean if image_mean is not None else self.image_mean
|
||||
image_std = image_std if image_std is not None else self.image_std
|
||||
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
|
||||
|
||||
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)
|
||||
|
||||
images_list = make_list_of_images(images)
|
||||
|
||||
if not valid_images(images_list[0]):
|
||||
raise ValueError(
|
||||
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
||||
"torch.Tensor, tf.Tensor or jax.ndarray."
|
||||
)
|
||||
|
||||
validate_preprocess_arguments(
|
||||
do_rescale=do_rescale,
|
||||
rescale_factor=rescale_factor,
|
||||
do_normalize=do_normalize,
|
||||
image_mean=image_mean,
|
||||
image_std=image_std,
|
||||
do_resize=do_resize,
|
||||
size=size,
|
||||
resample=resample,
|
||||
)
|
||||
|
||||
if do_convert_rgb:
|
||||
images_list = [[convert_to_rgb(image) for image in images] for images in images_list]
|
||||
|
||||
# All transformations expect numpy arrays.
|
||||
images_list = [[to_numpy_array(image) for image in images] for images in images_list]
|
||||
|
||||
if is_scaled_image(images_list[0][0]) and do_rescale:
|
||||
logger.warning_once(
|
||||
"It looks like you are trying to rescale already rescaled images. If the input"
|
||||
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
||||
)
|
||||
|
||||
if input_data_format is None:
|
||||
# We assume that all images have the same channel dimension format.
|
||||
input_data_format = infer_channel_dimension_format(images_list[0][0])
|
||||
|
||||
batch_images = []
|
||||
batch_image_sizes = []
|
||||
for sample_images in images_list:
|
||||
images = []
|
||||
image_sizes = []
|
||||
for image in sample_images:
|
||||
if do_resize:
|
||||
image = self.resize(
|
||||
image=image,
|
||||
size=size,
|
||||
patch_size=patch_size,
|
||||
resample=resample,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
|
||||
if do_rescale:
|
||||
image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
||||
|
||||
if do_normalize:
|
||||
image = self.normalize(
|
||||
image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
|
||||
)
|
||||
|
||||
images.append(image)
|
||||
image_sizes.append(get_image_size(image, input_data_format))
|
||||
batch_images.append(images)
|
||||
batch_image_sizes.append(image_sizes)
|
||||
|
||||
images_list = [
|
||||
[to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images]
|
||||
for images in batch_images
|
||||
]
|
||||
|
||||
# Convert to tensor type outside of BatchFeature to avoid batching the images of different sizes
|
||||
images_list = [[convert_to_tensor(image, return_tensors) for image in images] for images in images_list]
|
||||
return BatchMixFeature(data={"pixel_values": images_list, "image_sizes": batch_image_sizes}, tensor_type=None)
|
||||
517
src/transformers/models/pixtral/modeling_pixtral.py
Normal file
517
src/transformers/models/pixtral/modeling_pixtral.py
Normal file
@@ -0,0 +1,517 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch Pixtral model."""
|
||||
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from ... import PreTrainedModel
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_outputs import BaseModelOutput
|
||||
from ...utils import (
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
logging,
|
||||
)
|
||||
from .configuration_pixtral import PixtralVisionConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def position_ids_in_meshgrid(patch_embeds_list, max_width):
|
||||
positions = []
|
||||
for patch in patch_embeds_list:
|
||||
height, width = patch.shape[-2:]
|
||||
mesh = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij")
|
||||
h_grid, v_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1)
|
||||
ids = h_grid * max_width + v_grid
|
||||
positions.append(ids[:, 0])
|
||||
return torch.cat(positions)
|
||||
|
||||
|
||||
class PixtralRotaryEmbedding(nn.Module):
|
||||
"""
|
||||
The key with pixtral embedding is just that you have a frequency for each pixel positions.
|
||||
If you have height x width pixels (or embedding pixels)
|
||||
|
||||
then the frequency used for ROPE is given by indexing the pre_computed frequency on the
|
||||
width and height.
|
||||
|
||||
What you output is of dimension batch, height * width, dim with dim the embed dim.
|
||||
|
||||
This simply means that for each image hidden states, you are going to add
|
||||
a corresponding positional embedding, based on it's index in the grid.
|
||||
"""
|
||||
|
||||
def __init__(self, config, device):
|
||||
super().__init__()
|
||||
self.rope_type = "default"
|
||||
self.dim = config.head_dim
|
||||
self.base = config.rope_theta
|
||||
max_patches_per_side = config.image_size // config.patch_size
|
||||
freqs = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
|
||||
|
||||
h = torch.arange(max_patches_per_side, device=freqs.device)
|
||||
w = torch.arange(max_patches_per_side, device=freqs.device)
|
||||
|
||||
freqs_h = torch.outer(h, freqs[::2]).float()
|
||||
freqs_w = torch.outer(w, freqs[1::2]).float()
|
||||
inv_freq = torch.cat(
|
||||
[
|
||||
freqs_h[:, None, :].repeat(1, max_patches_per_side, 1),
|
||||
freqs_w[None, :, :].repeat(max_patches_per_side, 1, 1),
|
||||
],
|
||||
dim=-1,
|
||||
).reshape(-1, self.dim // 2) # we reshape to only index on the position indexes, not tuple of indexes
|
||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||
|
||||
# TODO maybe make it torch compatible later on. We can also just slice
|
||||
self.register_buffer("inv_freq", torch.cat((inv_freq, inv_freq), dim=-1), persistent=False)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, x, position_ids):
|
||||
if "dynamic" in self.rope_type:
|
||||
self._dynamic_frequency_update(position_ids, device=x.device)
|
||||
|
||||
# Core RoPE block
|
||||
freqs = self.inv_freq[position_ids]
|
||||
# position_ids_expanded = position_ids[:, None, :].float()
|
||||
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
emb = freqs
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||
|
||||
def _dynamic_frequency_update(self, position_ids, device):
|
||||
"""
|
||||
dynamic RoPE layers should recompute `inv_freq` in the following situations:
|
||||
1 - growing beyond the cached sequence length (allow scaling)
|
||||
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
|
||||
"""
|
||||
seq_len = torch.max(position_ids) + 1
|
||||
if seq_len > self.max_seq_len_cached: # growth
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(
|
||||
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
||||
)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
||||
self.max_seq_len_cached = seq_len
|
||||
|
||||
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||
self.max_seq_len_cached = self.original_max_seq_len
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
|
||||
Args:
|
||||
q (`torch.Tensor`): The query tensor.
|
||||
k (`torch.Tensor`): The key tensor.
|
||||
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
||||
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
||||
position_ids (`torch.Tensor`, *optional*):
|
||||
Deprecated and unused.
|
||||
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
||||
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
||||
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
||||
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
||||
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
||||
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
||||
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
||||
Returns:
|
||||
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
||||
"""
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
class PixtralAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.dropout = config.attention_dropout
|
||||
|
||||
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
|
||||
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
|
||||
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
|
||||
self.o_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_embeddings: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
batch_size, patches, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=0)
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
|
||||
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(batch_size, patches, -1)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Pixtral
|
||||
class PixtralMLP(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
||||
self.act_fn = ACT2FN[config.hidden_act]
|
||||
|
||||
def forward(self, hidden_state):
|
||||
return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Pixtral
|
||||
class PixtralRMSNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
"""
|
||||
PixtralRMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
def extra_repr(self):
|
||||
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
||||
|
||||
|
||||
class PixtralAttentionLayer(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.attention_norm = PixtralRMSNorm(config.hidden_size, eps=1e-5)
|
||||
self.feed_forward = PixtralMLP(config)
|
||||
self.attention = PixtralAttention(config)
|
||||
self.ffn_norm = PixtralRMSNorm(config.hidden_size, eps=1e-5)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
position_embeddings: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> Tuple[torch.FloatTensor]:
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`):
|
||||
Input to the layer of shape `(batch, seq_len, embed_dim)`.
|
||||
attention_mask (`torch.FloatTensor`):
|
||||
Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
|
||||
output_attentions (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
"""
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.attention_norm(hidden_states)
|
||||
hidden_states, attn_weights = self.attention(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_embeddings=position_embeddings,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.ffn_norm(hidden_states)
|
||||
hidden_states = self.feed_forward(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (attn_weights,)
|
||||
return outputs
|
||||
|
||||
|
||||
class PixtralTransformer(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layers = torch.nn.ModuleList()
|
||||
for _ in range(config.num_hidden_layers):
|
||||
self.layers.append(PixtralAttentionLayer(config))
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_embeddings: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutput]:
|
||||
r"""
|
||||
Args:
|
||||
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
||||
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
|
||||
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
||||
than the model's internal embedding lookup matrix.
|
||||
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
||||
for more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
encoder_states = () if output_hidden_states else None
|
||||
all_attentions = () if output_attentions else None
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
for encoder_layer in self.layers:
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
encoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_embeddings,
|
||||
output_attentions,
|
||||
)
|
||||
else:
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_embeddings=position_embeddings,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if output_attentions:
|
||||
all_attentions = all_attentions + (layer_outputs[1],)
|
||||
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states, hidden_states=[hidden_states], attentions=all_attentions
|
||||
)
|
||||
|
||||
|
||||
PIXTRAL_START_DOCSTRING = r"""
|
||||
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
||||
etc.)
|
||||
|
||||
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
||||
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
||||
and behavior.
|
||||
|
||||
Parameters:
|
||||
config ([`PixtralVisionConfig`] or [`PixtralVisionConfig`]):
|
||||
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
||||
load the weights associated with the model, only the configuration. Check out the
|
||||
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
|
||||
PIXTRAL_START_DOCSTRING,
|
||||
)
|
||||
class PixtralPreTrainedModel(PreTrainedModel):
|
||||
config_class = PixtralVisionConfig
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["PixtralVisionAttention"]
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
_supports_cache_class = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
# important: this ported version of Pixtral isn't meant for training from scratch - only
|
||||
# inference and fine-tuning - so the proper init weights code has been removed - the original codebase
|
||||
# https://github.com/haotian-liu/LLaVA/tree/main/pixtral should serve for that purpose
|
||||
std = (
|
||||
self.config.initializer_range
|
||||
if hasattr(self.config, "initializer_range")
|
||||
else self.config.text_config.initializer_range
|
||||
)
|
||||
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
|
||||
PIXTRAL_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
pixel_values: list of N_img images of variable sizes,
|
||||
each of shape (C, H, W)
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||
tensors for more detail.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
|
||||
|
||||
def generate_block_attention_mask(patch_embeds_list, tensor):
|
||||
dtype = tensor.dtype
|
||||
device = tensor.device
|
||||
seq_len = tensor.shape[1]
|
||||
d_min = torch.finfo(dtype).min
|
||||
causal_mask = torch.full((seq_len, seq_len), fill_value=d_min, dtype=dtype, device=device)
|
||||
|
||||
block_end_idx = torch.tensor(patch_embeds_list).cumsum(-1)
|
||||
block_start_idx = torch.tensor([0] + patch_embeds_list[:-1]).cumsum(-1)
|
||||
for start, end in zip(block_start_idx, block_end_idx):
|
||||
causal_mask[start:end, start:end] = 0
|
||||
|
||||
causal_mask = causal_mask[None, None, :, :].expand(tensor.shape[0], 1, -1, -1)
|
||||
return causal_mask
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""The PIXTRAL model which consists of a vision backbone and a language model.""",
|
||||
PIXTRAL_START_DOCSTRING,
|
||||
)
|
||||
class PixtralModel(PixtralPreTrainedModel):
|
||||
base_model_prefix = "vision_encoder"
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.patch_conv = nn.Conv2d(
|
||||
in_channels=config.num_channels,
|
||||
out_channels=config.hidden_size,
|
||||
kernel_size=config.patch_size,
|
||||
stride=config.patch_size,
|
||||
bias=False,
|
||||
)
|
||||
self.ln_pre = PixtralRMSNorm(config.hidden_size, eps=1e-5)
|
||||
self.transformer = PixtralTransformer(config)
|
||||
self.patch_positional_embedding = PixtralRotaryEmbedding(config, device=self.device)
|
||||
|
||||
@add_start_docstrings_to_model_forward(PIXTRAL_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: List[torch.Tensor],
|
||||
output_hidden_states: Optional[bool] = False,
|
||||
output_attentions: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> Union[Tuple, BaseModelOutput]:
|
||||
"""
|
||||
Returns:
|
||||
pixel_values: tensor of token features for
|
||||
all tokens of all images of shape (N_toks, D)
|
||||
"""
|
||||
# pass images through initial convolution independently
|
||||
patch_embeds_list = [self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in pixel_values]
|
||||
|
||||
# flatten to a single sequence
|
||||
patch_embeds = torch.cat([p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1)
|
||||
patch_embeds = self.ln_pre(patch_embeds)
|
||||
|
||||
# positional embeddings
|
||||
position_ids = position_ids_in_meshgrid(
|
||||
patch_embeds_list, max_width=self.config.image_size // self.config.patch_size
|
||||
).to(self.device)
|
||||
|
||||
position_embedding = self.patch_positional_embedding(patch_embeds, position_ids)
|
||||
attention_mask = generate_block_attention_mask(
|
||||
[p.shape[-2] * p.shape[-1] for p in patch_embeds_list], patch_embeds
|
||||
)
|
||||
return self.transformer(patch_embeds, attention_mask, position_embedding)
|
||||
282
src/transformers/models/pixtral/processing_pixtral.py
Normal file
282
src/transformers/models/pixtral/processing_pixtral.py
Normal file
@@ -0,0 +1,282 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Processor class for Pixtral.
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from ...feature_extraction_utils import BatchFeature
|
||||
from ...image_utils import ImageInput, is_valid_image, load_image
|
||||
from ...processing_utils import ProcessorMixin
|
||||
from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
|
||||
from ...utils import TensorType, is_torch_device, is_torch_dtype, is_torch_tensor, logging, requires_backends
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# Copied from transformers.models.idefics2.processing_idefics2.is_url
|
||||
def is_url(val) -> bool:
|
||||
return isinstance(val, str) and val.startswith("http")
|
||||
|
||||
|
||||
# Copied from transformers.models.idefics2.processing_idefics2.is_image_or_image_url
|
||||
def is_image_or_image_url(elem):
|
||||
return is_url(elem) or is_valid_image(elem)
|
||||
|
||||
|
||||
# Copied from transformers.models.pixtral.image_processing_pixtral.BatchMixFeature
|
||||
class BatchMixFeature(BatchFeature):
|
||||
def to(self, *args, **kwargs) -> "BatchMixFeature":
|
||||
"""
|
||||
Send all values to device by calling `v.to(*args, **kwargs)` (PyTorch only). This should support casting in
|
||||
different `dtypes` and sending the `BatchFeature` to a different `device`.
|
||||
|
||||
Args:
|
||||
args (`Tuple`):
|
||||
Will be passed to the `to(...)` function of the tensors.
|
||||
kwargs (`Dict`, *optional*):
|
||||
Will be passed to the `to(...)` function of the tensors.
|
||||
|
||||
Returns:
|
||||
[`BatchFeature`]: The same instance after modification.
|
||||
"""
|
||||
requires_backends(self, ["torch"])
|
||||
import torch # noqa
|
||||
|
||||
new_data = {}
|
||||
device = kwargs.get("device")
|
||||
# Check if the args are a device or a dtype
|
||||
if device is None and len(args) > 0:
|
||||
# device should be always the first argument
|
||||
arg = args[0]
|
||||
if is_torch_dtype(arg):
|
||||
# The first argument is a dtype
|
||||
pass
|
||||
elif isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int):
|
||||
device = arg
|
||||
else:
|
||||
# it's something else
|
||||
raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")
|
||||
# We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
|
||||
for k, v in self.items():
|
||||
# check if v is a floating point
|
||||
if isinstance(v, list):
|
||||
new_data[k] = [
|
||||
element.to(*args, **kwargs) for sample in v for element in sample if is_torch_tensor(element)
|
||||
]
|
||||
elif torch.is_floating_point(v):
|
||||
# cast and send to device
|
||||
new_data[k] = v.to(*args, **kwargs)
|
||||
elif device is not None:
|
||||
new_data[k] = v.to(device=device)
|
||||
else:
|
||||
new_data[k] = v
|
||||
self.data = new_data
|
||||
return self
|
||||
|
||||
|
||||
class PixtralProcessor(ProcessorMixin):
|
||||
r"""
|
||||
Constructs a Pixtral processor which wraps a Pixtral image processor and a Pixtral tokenizer into a single processor.
|
||||
|
||||
[`PixtralProcessor`] offers all the functionalities of [`CLIPImageProcessor`] and [`LlamaTokenizerFast`]. See the
|
||||
[`~PixtralProcessor.__call__`] and [`~PixtralProcessor.decode`] for more information.
|
||||
|
||||
Args:
|
||||
image_processor ([`PixtralImageProcessor`], *optional*):
|
||||
The image processor is a required input.
|
||||
tokenizer ([`LlamaTokenizerFast`], *optional*):
|
||||
The tokenizer is a required input.
|
||||
patch_size (`int`, *optional*, defaults to 16):
|
||||
Patch size from the vision tower.
|
||||
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
|
||||
in a chat into a tokenizable string.
|
||||
image_token (`str`, *optional*, defaults to `"[IMG]"`):
|
||||
Special token used to denote image location.
|
||||
image_break_token (`str`, *optional*, defaults to `"[IMG_BREAK]"`):
|
||||
Special token used to denote the end of a line of pixels in an image.
|
||||
image_end_token (`str`, *optional*, defaults to `"[IMG_END]"`):
|
||||
Special token used to denote the end of an image input.
|
||||
"""
|
||||
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
valid_kwargs = [
|
||||
"chat_template",
|
||||
"patch_size",
|
||||
"image_token",
|
||||
"image_break_token",
|
||||
"image_end_token",
|
||||
]
|
||||
image_processor_class = "AutoImageProcessor"
|
||||
tokenizer_class = "AutoTokenizer"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_processor=None,
|
||||
tokenizer=None,
|
||||
patch_size: int = 16,
|
||||
chat_template=None,
|
||||
image_token="[IMG]", # set the default and let users change if they have peculiar special tokens in rare cases
|
||||
image_break_token="[IMG_BREAK]",
|
||||
image_end_token="[IMG_END]",
|
||||
**kwargs,
|
||||
):
|
||||
self.patch_size = patch_size
|
||||
self.image_token = image_token
|
||||
self.image_break_token = image_break_token
|
||||
self.image_end_token = image_end_token
|
||||
super().__init__(image_processor, tokenizer, chat_template=chat_template)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
|
||||
images: ImageInput = None,
|
||||
padding: Union[bool, str, PaddingStrategy] = False,
|
||||
truncation: Union[bool, str, TruncationStrategy] = None,
|
||||
max_length=None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
|
||||
) -> BatchMixFeature:
|
||||
"""
|
||||
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
|
||||
and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
|
||||
the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to
|
||||
CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring
|
||||
of the above two methods for more information.
|
||||
|
||||
Args:
|
||||
text (`str`, `List[str]`, `List[List[str]]`):
|
||||
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
||||
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
||||
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
||||
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
||||
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
||||
tensor. Both channels-first and channels-last formats are supported.
|
||||
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
|
||||
Select a strategy to pad the returned sequences (according to the model's padding side and padding
|
||||
index) among:
|
||||
- `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
|
||||
sequence if provided).
|
||||
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
|
||||
acceptable input length for the model if that argument is not provided.
|
||||
- `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
|
||||
lengths).
|
||||
max_length (`int`, *optional*):
|
||||
Maximum length of the returned list and optionally padding length (see above).
|
||||
truncation (`bool`, *optional*):
|
||||
Activates truncation to cut input sequences longer than `max_length` to `max_length`.
|
||||
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
||||
If set, will return tensors of a particular framework. Acceptable values are:
|
||||
|
||||
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
||||
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
||||
- `'np'`: Return NumPy `np.ndarray` objects.
|
||||
- `'jax'`: Return JAX `jnp.ndarray` objects.
|
||||
|
||||
Returns:
|
||||
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
|
||||
|
||||
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
|
||||
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
|
||||
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
|
||||
`None`).
|
||||
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
|
||||
"""
|
||||
if images is not None:
|
||||
if is_image_or_image_url(images):
|
||||
images = [[images]]
|
||||
elif isinstance(images, list) and is_image_or_image_url(images[0]):
|
||||
images = [images]
|
||||
elif (
|
||||
not isinstance(images, list)
|
||||
and not isinstance(images[0], list)
|
||||
and not is_image_or_image_url(images[0][0])
|
||||
):
|
||||
raise ValueError(
|
||||
"Invalid input images. Please provide a single image or a list of images or a list of list of images."
|
||||
)
|
||||
images = [[load_image(im) for im in sample] for sample in images]
|
||||
image_inputs = self.image_processor(images, patch_size=self.patch_size, return_tensors=return_tensors)
|
||||
else:
|
||||
image_inputs = {}
|
||||
|
||||
if isinstance(text, str):
|
||||
text = [text]
|
||||
elif not isinstance(text, list) and not isinstance(text[0], str):
|
||||
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
|
||||
|
||||
# try to expand inputs in processing if we have the necessary parts
|
||||
prompt_strings = text
|
||||
if image_inputs.get("pixel_values") is not None:
|
||||
# Replace the image token with the expanded image token sequence
|
||||
images = image_inputs["pixel_values"]
|
||||
image_sizes = image_inputs.pop("image_sizes")
|
||||
prompt_strings = []
|
||||
|
||||
for sample_images, sample_image_sizes, sample in zip(images, image_sizes, text):
|
||||
replace_strings = []
|
||||
# First calculate the number of tokens needed for each image and put in a placeholder
|
||||
for image, image_size in zip(sample_images, sample_image_sizes):
|
||||
height, width = image_size
|
||||
num_height_tokens = height // self.patch_size
|
||||
num_width_tokens = width // self.patch_size
|
||||
replace_tokens = [
|
||||
[self.image_token] * num_width_tokens + [self.image_break_token]
|
||||
] * num_height_tokens
|
||||
# Flatten list
|
||||
replace_tokens = [item for sublist in replace_tokens for item in sublist]
|
||||
replace_tokens[-1] = self.image_end_token
|
||||
replace_str = "".join(replace_tokens)
|
||||
replace_strings.append(replace_str)
|
||||
sample = sample.replace(self.image_token, "<placeholder>", 1)
|
||||
|
||||
while "<placeholder>" in sample:
|
||||
replace_str = replace_strings.pop(0)
|
||||
sample = sample.replace("<placeholder>", replace_str, 1)
|
||||
|
||||
prompt_strings.append(sample)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt_strings,
|
||||
return_tensors=return_tensors,
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
max_length=max_length,
|
||||
)
|
||||
return BatchMixFeature(data={**text_inputs, **image_inputs})
|
||||
|
||||
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
||||
refer to the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.batch_decode(*args, **kwargs)
|
||||
|
||||
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
|
||||
def decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
|
||||
the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.decode(*args, **kwargs)
|
||||
|
||||
@property
|
||||
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
|
||||
def model_input_names(self):
|
||||
tokenizer_input_names = self.tokenizer.model_input_names
|
||||
image_processor_input_names = self.image_processor.model_input_names
|
||||
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
||||
@@ -7067,6 +7067,20 @@ class Pix2StructVisionModel(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class PixtralModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class PixtralPreTrainedModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class PLBartForCausalLM(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -506,6 +506,13 @@ class Pix2StructImageProcessor(metaclass=DummyObject):
|
||||
requires_backends(self, ["vision"])
|
||||
|
||||
|
||||
class PixtralImageProcessor(metaclass=DummyObject):
|
||||
_backends = ["vision"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["vision"])
|
||||
|
||||
|
||||
class PoolFormerFeatureExtractor(metaclass=DummyObject):
|
||||
_backends = ["vision"]
|
||||
|
||||
|
||||
@@ -569,3 +569,50 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
|
||||
# check that both inputs are handled correctly and generate the same output
|
||||
self.assertListEqual(output_expanded[:, -20:].tolist(), output[:, -20:].tolist())
|
||||
|
||||
@slow
|
||||
@require_bitsandbytes
|
||||
def test_pixtral(self):
|
||||
model_id = "hf-internal-testing/pixtral-12b"
|
||||
model = LlavaForConditionalGeneration.from_pretrained(model_id)
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
|
||||
IMG_URLS = [
|
||||
Image.open(requests.get("https://picsum.photos/id/237/400/300", stream=True).raw),
|
||||
Image.open(requests.get("https://picsum.photos/id/231/200/300", stream=True).raw),
|
||||
Image.open(requests.get("https://picsum.photos/id/27/500/500", stream=True).raw),
|
||||
Image.open(requests.get("https://picsum.photos/id/17/150/600", stream=True).raw),
|
||||
]
|
||||
PROMPT = "<s>[INST]Describe the images.\n[IMG][IMG][IMG][IMG][/INST]"
|
||||
|
||||
# image = Image.open(requests.get(url, stream=True).raw)
|
||||
inputs = processor(text=PROMPT, images=IMG_URLS, return_tensors="pt").to("cuda")
|
||||
generate_ids = model.generate(**inputs, max_new_tokens=500)
|
||||
ouptut = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_GENERATION = """
|
||||
Describe the images.
|
||||
Sure, let's break down each image description:
|
||||
|
||||
1. **Image 1:**
|
||||
- **Description:** A black dog with a glossy coat is sitting on a wooden floor. The dog has a focused expression and is looking directly at the camera.
|
||||
- **Details:** The wooden floor has a rustic appearance with visible wood grain patterns. The dog's eyes are a striking color, possibly brown or amber, which contrasts with its black fur.
|
||||
|
||||
2. **Image 2:**
|
||||
- **Description:** A scenic view of a mountainous landscape with a winding road cutting through it. The road is surrounded by lush green vegetation and leads to a distant valley.
|
||||
- **Details:** The mountains are rugged with steep slopes, and the sky is clear, indicating good weather. The winding road adds a sense of depth and perspective to the image.
|
||||
|
||||
3. **Image 3:**
|
||||
- **Description:** A beach scene with waves crashing against the shore. There are several people in the water and on the beach, enjoying the waves and the sunset.
|
||||
- **Details:** The waves are powerful, creating a dynamic and lively atmosphere. The sky is painted with hues of orange and pink from the setting sun, adding a warm glow to the scene.
|
||||
|
||||
4. **Image 4:**
|
||||
- **Description:** A garden path leading to a large tree with a bench underneath it. The path is bordered by well-maintained grass and flowers.
|
||||
- **Details:** The path is made of small stones or gravel, and the tree provides a shaded area with the bench invitingly placed beneath it. The surrounding area is lush and green, suggesting a well-kept garden.
|
||||
|
||||
Each image captures a different scene, from a close-up of a dog to expansive natural landscapes, showcasing various elements of nature and human interaction with it.
|
||||
"""
|
||||
# fmt: on
|
||||
# check that both inputs are handled correctly and generate the same output
|
||||
self.assertListEqual(ouptut, EXPECTED_GENERATION)
|
||||
|
||||
0
tests/models/pixtral/__init__.py
Normal file
0
tests/models/pixtral/__init__.py
Normal file
217
tests/models/pixtral/test_image_processing_pixtral.py
Normal file
217
tests/models/pixtral/test_image_processing_pixtral.py
Normal file
@@ -0,0 +1,217 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import random
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.testing_utils import require_torch, require_vision
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
from transformers import PixtralImageProcessor
|
||||
|
||||
|
||||
class PixtralImageProcessingTester(unittest.TestCase):
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=7,
|
||||
num_channels=3,
|
||||
image_size=18,
|
||||
max_num_images_per_sample=3,
|
||||
min_resolution=30,
|
||||
max_resolution=400,
|
||||
do_resize=True,
|
||||
size=None,
|
||||
patch_size=None,
|
||||
do_normalize=True,
|
||||
image_mean=[0.48145466, 0.4578275, 0.40821073],
|
||||
image_std=[0.26862954, 0.26130258, 0.27577711],
|
||||
do_convert_rgb=True,
|
||||
):
|
||||
size = size if size is not None else {"longest_edge": 24}
|
||||
patch_size = patch_size if patch_size is not None else {"height": 8, "width": 8}
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.num_channels = num_channels
|
||||
self.image_size = image_size
|
||||
self.max_num_images_per_sample = max_num_images_per_sample
|
||||
self.min_resolution = min_resolution
|
||||
self.max_resolution = max_resolution
|
||||
self.do_resize = do_resize
|
||||
self.size = size
|
||||
self.patch_size = patch_size
|
||||
self.do_normalize = do_normalize
|
||||
self.image_mean = image_mean
|
||||
self.image_std = image_std
|
||||
self.do_convert_rgb = do_convert_rgb
|
||||
|
||||
def prepare_image_processor_dict(self):
|
||||
return {
|
||||
"do_resize": self.do_resize,
|
||||
"size": self.size,
|
||||
"patch_size": self.patch_size,
|
||||
"do_normalize": self.do_normalize,
|
||||
"image_mean": self.image_mean,
|
||||
"image_std": self.image_std,
|
||||
"do_convert_rgb": self.do_convert_rgb,
|
||||
}
|
||||
|
||||
def expected_output_image_shape(self, image):
|
||||
if isinstance(image, Image.Image):
|
||||
width, height = image.size
|
||||
elif isinstance(image, np.ndarray):
|
||||
height, width = image.shape[:2]
|
||||
elif isinstance(image, torch.Tensor):
|
||||
height, width = image.shape[-2:]
|
||||
|
||||
max_height = max_width = self.size.get("longest_edge")
|
||||
|
||||
ratio = max(height / max_height, width / max_width)
|
||||
if ratio > 1:
|
||||
height = int(np.ceil(height / ratio))
|
||||
width = int(np.ceil(width / ratio))
|
||||
|
||||
patch_height, patch_width = self.patch_size["height"], self.patch_size["width"]
|
||||
num_height_tokens = (height - 1) // patch_height + 1
|
||||
num_width_tokens = (width - 1) // patch_width + 1
|
||||
|
||||
height = num_height_tokens * patch_height
|
||||
width = num_width_tokens * patch_width
|
||||
|
||||
return self.num_channels, height, width
|
||||
|
||||
def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
|
||||
# Use prepare_image_inputs to make a list of list of single images
|
||||
|
||||
images_list = []
|
||||
for _ in range(self.batch_size):
|
||||
images = []
|
||||
for _ in range(random.randint(1, self.max_num_images_per_sample)):
|
||||
img = prepare_image_inputs(
|
||||
batch_size=1,
|
||||
num_channels=self.num_channels,
|
||||
min_resolution=self.min_resolution,
|
||||
max_resolution=self.max_resolution,
|
||||
equal_resolution=equal_resolution,
|
||||
numpify=numpify,
|
||||
torchify=torchify,
|
||||
)[0]
|
||||
images.append(img)
|
||||
images_list.append(images)
|
||||
return images_list
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
class PixtralImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
image_processing_class = PixtralImageProcessor if is_vision_available() else None
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.image_processor_tester = PixtralImageProcessingTester(self)
|
||||
|
||||
@property
|
||||
def image_processor_dict(self):
|
||||
return self.image_processor_tester.prepare_image_processor_dict()
|
||||
|
||||
def test_image_processor_properties(self):
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
self.assertTrue(hasattr(image_processing, "do_resize"))
|
||||
self.assertTrue(hasattr(image_processing, "size"))
|
||||
self.assertTrue(hasattr(image_processing, "patch_size"))
|
||||
self.assertTrue(hasattr(image_processing, "do_rescale"))
|
||||
self.assertTrue(hasattr(image_processing, "rescale_factor"))
|
||||
self.assertTrue(hasattr(image_processing, "do_normalize"))
|
||||
self.assertTrue(hasattr(image_processing, "image_mean"))
|
||||
self.assertTrue(hasattr(image_processing, "image_std"))
|
||||
self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
|
||||
|
||||
def test_call_pil(self):
|
||||
# Initialize image_processing
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
# create random PIL images
|
||||
image_inputs_list = self.image_processor_tester.prepare_image_inputs()
|
||||
for image_inputs in image_inputs_list:
|
||||
for image in image_inputs:
|
||||
self.assertIsInstance(image, Image.Image)
|
||||
|
||||
# Test not batched input
|
||||
encoded_images = image_processing(image_inputs_list[0][0], return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs_list[0][0])
|
||||
self.assertEqual(tuple(encoded_images[0][0].shape), expected_output_image_shape)
|
||||
|
||||
# Test batched
|
||||
batch_encoded_images = image_processing(image_inputs_list, return_tensors="pt").pixel_values
|
||||
for encoded_images, images in zip(batch_encoded_images, image_inputs_list):
|
||||
for encoded_image, image in zip(encoded_images, images):
|
||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image)
|
||||
self.assertEqual(tuple(encoded_image.shape), expected_output_image_shape)
|
||||
|
||||
def test_call_numpy(self):
|
||||
# Initialize image_processing
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
# create random numpy tensors
|
||||
image_inputs_list = self.image_processor_tester.prepare_image_inputs(numpify=True)
|
||||
for image_inputs in image_inputs_list:
|
||||
for image in image_inputs:
|
||||
self.assertIsInstance(image, np.ndarray)
|
||||
|
||||
# Test not batched input
|
||||
encoded_images = image_processing(image_inputs_list[0][0], return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs_list[0][0])
|
||||
self.assertEqual(tuple(encoded_images[0][0].shape), expected_output_image_shape)
|
||||
|
||||
# Test batched
|
||||
batch_encoded_images = image_processing(image_inputs_list, return_tensors="pt").pixel_values
|
||||
for encoded_images, images in zip(batch_encoded_images, image_inputs_list):
|
||||
for encoded_image, image in zip(encoded_images, images):
|
||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image)
|
||||
self.assertEqual(tuple(encoded_image.shape), expected_output_image_shape)
|
||||
|
||||
def test_call_pytorch(self):
|
||||
# Initialize image_processing
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
# create random PyTorch tensors
|
||||
image_inputs_list = self.image_processor_tester.prepare_image_inputs(torchify=True)
|
||||
for image_inputs in image_inputs_list:
|
||||
for image in image_inputs:
|
||||
self.assertIsInstance(image, torch.Tensor)
|
||||
|
||||
# Test not batched input
|
||||
encoded_images = image_processing(image_inputs_list[0][0], return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs_list[0][0])
|
||||
self.assertEqual(tuple(encoded_images[0][0].shape), expected_output_image_shape)
|
||||
|
||||
# Test batched
|
||||
batch_encoded_images = image_processing(image_inputs_list, return_tensors="pt").pixel_values
|
||||
for encoded_images, images in zip(batch_encoded_images, image_inputs_list):
|
||||
for encoded_image, image in zip(encoded_images, images):
|
||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image)
|
||||
self.assertEqual(tuple(encoded_image.shape), expected_output_image_shape)
|
||||
|
||||
@unittest.skip(reason="PixtralImageProcessor doesn't treat 4 channel PIL and numpy consistently yet") # FIXME Amy
|
||||
def test_call_numpy_4_channels(self):
|
||||
pass
|
||||
292
tests/models/pixtral/test_modeling_pixtral.py
Normal file
292
tests/models/pixtral/test_modeling_pixtral.py
Normal file
@@ -0,0 +1,292 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Testing suite for the PyTorch Pixtral model."""
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import requests
|
||||
|
||||
from transformers import (
|
||||
AutoProcessor,
|
||||
PixtralModel,
|
||||
PixtralVisionConfig,
|
||||
is_torch_available,
|
||||
is_vision_available,
|
||||
)
|
||||
from transformers.testing_utils import (
|
||||
require_bitsandbytes,
|
||||
require_torch,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
else:
|
||||
is_torch_greater_or_equal_than_2_0 = False
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class PixtralModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=12,
|
||||
image_size=30,
|
||||
patch_size=2,
|
||||
num_channels=3,
|
||||
is_training=True,
|
||||
hidden_size=32,
|
||||
projection_dim=32,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
dropout=0.1,
|
||||
attention_dropout=0.1,
|
||||
initializer_range=0.02,
|
||||
scope=None,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.image_size = image_size
|
||||
self.patch_size = patch_size
|
||||
self.num_channels = num_channels
|
||||
self.is_training = is_training
|
||||
self.hidden_size = hidden_size
|
||||
self.projection_dim = projection_dim
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.dropout = dropout
|
||||
self.attention_dropout = attention_dropout
|
||||
self.initializer_range = initializer_range
|
||||
self.scope = scope
|
||||
|
||||
# in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
self.seq_length = num_patches + 1
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||
config = self.get_config()
|
||||
|
||||
return config, pixel_values
|
||||
|
||||
def get_config(self):
|
||||
return PixtralVisionConfig(
|
||||
image_size=self.image_size,
|
||||
patch_size=self.patch_size,
|
||||
num_channels=self.num_channels,
|
||||
hidden_size=self.hidden_size,
|
||||
projection_dim=self.projection_dim,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
intermediate_size=self.intermediate_size,
|
||||
dropout=self.dropout,
|
||||
attention_dropout=self.attention_dropout,
|
||||
initializer_range=self.initializer_range,
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, pixel_values):
|
||||
model = PixtralModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
result = model(pixel_values)
|
||||
# expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
|
||||
image_size = (self.image_size, self.image_size)
|
||||
patch_size = (self.patch_size, self.patch_size)
|
||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
|
||||
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
|
||||
|
||||
def create_and_check_model_with_projection(self, config, pixel_values):
|
||||
model = PixtralModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
result = model(pixel_values)
|
||||
# expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
|
||||
image_size = (self.image_size, self.image_size)
|
||||
patch_size = (self.patch_size, self.patch_size)
|
||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
|
||||
self.parent.assertEqual(result.image_embeds.shape, (self.batch_size, self.projection_dim))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, pixel_values = config_and_inputs
|
||||
inputs_dict = {"pixel_values": pixel_values}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_torch
|
||||
class PixtralModelModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
"""
|
||||
Model tester for `PixtralModel`.
|
||||
"""
|
||||
|
||||
all_model_classes = (PixtralModel,) if is_torch_available() else ()
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = PixtralModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=PixtralVisionConfig, has_text_modality=False)
|
||||
|
||||
@unittest.skip("model does not support input embeds")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("model does not support input embeds")
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
|
||||
)
|
||||
def test_training_gradient_checkpointing(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
|
||||
)
|
||||
def test_training_gradient_checkpointing_use_reentrant(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
|
||||
)
|
||||
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Compile not yet supported because in Pixtral models")
|
||||
def test_sdpa_can_compile_dynamic(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Compile not yet supported because in Pixtral models")
|
||||
def test_sdpa_can_dispatch_on_flash(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Not supported yet")
|
||||
def test_attention_outputs(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Not supported yet")
|
||||
def test_cpu_offload(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Not supported yet")
|
||||
def test_batching_equivalence(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Not supported yet")
|
||||
def test_disk_offload_bin(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Not supported yet")
|
||||
def test_retain_grad_hidden_states_attentions(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Not supported yet")
|
||||
def test_multi_gpu_data_parallel_forward(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Not supported yet")
|
||||
def test_model_parallelism(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Not supported yet")
|
||||
def test_model_outputs_equivalence(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Not supported yet")
|
||||
def test_save_load(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Not supported yet")
|
||||
def test_model_get_set_embeddings(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Not supported yet")
|
||||
def test_resize_tokens_embeddings(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Not supported yet")
|
||||
def test_model_main_input_name(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Not supported yet")
|
||||
def test_initialization(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Not supported yet")
|
||||
def test_hidden_states_output(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Not supported yet")
|
||||
def test_gradient_checkpointing_backward_compatibility(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Not supported yet")
|
||||
def test_feed_forward_chunking(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Not supported yet")
|
||||
def test_disk_offload_safetensors(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Not supported yet")
|
||||
def test_determinism(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
class PixtralModelIntegrationTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.processor = AutoProcessor.from_pretrained("hf-internal-testing/pixtral-12b")
|
||||
|
||||
def tearDown(self):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@slow
|
||||
@require_bitsandbytes
|
||||
def test_small_model_integration_test(self):
|
||||
# Let' s make sure we test the preprocessing to replace what is used
|
||||
model = PixtralModel.from_pretrained("hf-internal-testing/pixtral-12b", load_in_4bit=True)
|
||||
|
||||
prompt = "<s>[INST][IMG]\nWhat are the things I should be cautious about when I visit this place?[/INST]"
|
||||
image_file = "https://pixtral-vl.github.io/static/images/view.jpg"
|
||||
raw_image = Image.open(requests.get(image_file, stream=True).raw)
|
||||
inputs = self.processor(prompt, raw_image, return_tensors="pt")
|
||||
|
||||
EXPECTED_INPUT_IDS = torch.tensor([[1, 32000, 28705, 13, 11123, 28747, 1824, 460, 272, 1722,315, 1023, 347, 13831, 925, 684, 739, 315, 3251, 456,1633, 28804, 13, 4816, 8048, 12738, 28747]]) # fmt: skip
|
||||
self.assertTrue(torch.equal(inputs["input_ids"], EXPECTED_INPUT_IDS))
|
||||
|
||||
output = model.generate(**inputs, max_new_tokens=20)
|
||||
EXPECTED_DECODED_TEXT = "\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT: When visiting this place, there are a few things one should be cautious about. Firstly," # fmt: skip
|
||||
|
||||
self.assertEqual(
|
||||
self.processor.decode(output[0], skip_special_tokens=True),
|
||||
EXPECTED_DECODED_TEXT,
|
||||
)
|
||||
233
tests/models/pixtral/test_processor_pixtral.py
Normal file
233
tests/models/pixtral/test_processor_pixtral.py
Normal file
@@ -0,0 +1,233 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import unittest
|
||||
|
||||
import requests
|
||||
import torch
|
||||
|
||||
from transformers.testing_utils import require_vision
|
||||
from transformers.utils import is_vision_available
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
from transformers import AutoTokenizer, PixtralImageProcessor, PixtralProcessor
|
||||
|
||||
|
||||
@require_vision
|
||||
class PixtralProcessorTest(unittest.TestCase):
|
||||
processor_class = PixtralProcessor
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.url_0 = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
||||
cls.image_0 = Image.open(requests.get(cls.url_0, stream=True).raw)
|
||||
cls.url_1 = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
cls.image_1 = Image.open(requests.get(cls.url_1, stream=True).raw)
|
||||
cls.url_2 = "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg"
|
||||
cls.image_2 = Image.open(requests.get(cls.url_2, stream=True).raw)
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
# FIXME - just load the processor directly from the checkpoint
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/pixtral-12b")
|
||||
image_processor = PixtralImageProcessor()
|
||||
self.processor = PixtralProcessor(tokenizer=tokenizer, image_processor=image_processor)
|
||||
|
||||
@unittest.skip("No chat template was set for this model (yet)")
|
||||
def test_chat_template(self):
|
||||
expected_prompt = "USER: [IMG]\nWhat is shown in this image? ASSISTANT:"
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": "What is shown in this image?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
formatted_prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True)
|
||||
self.assertEqual(expected_prompt, formatted_prompt)
|
||||
|
||||
@unittest.skip("No chat template was set for this model (yet)")
|
||||
def test_image_token_filling(self):
|
||||
# Important to check with non square image
|
||||
image = torch.randint(0, 2, (3, 500, 316))
|
||||
expected_image_tokens = 1526
|
||||
image_token_index = 32000
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": "What is shown in this image?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
inputs = self.processor(
|
||||
text=[self.processor.apply_chat_template(messages)],
|
||||
images=[image],
|
||||
return_tensors="pt",
|
||||
)
|
||||
image_tokens = (inputs["input_ids"] == image_token_index).sum().item()
|
||||
self.assertEqual(expected_image_tokens, image_tokens)
|
||||
|
||||
def test_processor_with_single_image(self):
|
||||
prompt_string = "USER: [IMG]\nWhat's the content of the image? ASSISTANT:"
|
||||
|
||||
# Make small for checking image token expansion
|
||||
self.processor.image_processor.size = {"longest_edge": 30}
|
||||
self.processor.image_processor.patch_size = {"height": 2, "width": 2}
|
||||
|
||||
# Test passing in an image
|
||||
inputs_image = self.processor(text=prompt_string, images=self.image_0, return_tensors="pt")
|
||||
self.assertIn("input_ids", inputs_image)
|
||||
self.assertTrue(len(inputs_image["input_ids"]) == 1)
|
||||
self.assertIsInstance(inputs_image["input_ids"], torch.Tensor)
|
||||
self.assertIsInstance(inputs_image["pixel_values"], list)
|
||||
self.assertTrue(len(inputs_image["pixel_values"]) == 1)
|
||||
self.assertIsInstance(inputs_image["pixel_values"][0], list)
|
||||
self.assertTrue(len(inputs_image["pixel_values"][0]) == 1)
|
||||
self.assertIsInstance(inputs_image["pixel_values"][0][0], torch.Tensor)
|
||||
|
||||
# fmt: off
|
||||
input_ids = inputs_image["input_ids"]
|
||||
self.assertEqual(
|
||||
input_ids[0].tolist(),
|
||||
# Equivalent to "USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the content of the image? ASSISTANT:"
|
||||
[21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 4701, 1307, 1278, 3937, 1063, 1349, 4290, 16002, 41150, 1058]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
# Test passing in a url
|
||||
inputs_url = self.processor(text=prompt_string, images=self.url_0, return_tensors="pt")
|
||||
self.assertIn("input_ids", inputs_url)
|
||||
self.assertTrue(len(inputs_url["input_ids"]) == 1)
|
||||
self.assertIsInstance(inputs_url["input_ids"], torch.Tensor)
|
||||
self.assertIsInstance(inputs_url["pixel_values"], list)
|
||||
self.assertTrue(len(inputs_url["pixel_values"]) == 1)
|
||||
self.assertIsInstance(inputs_url["pixel_values"][0], list)
|
||||
self.assertTrue(len(inputs_url["pixel_values"][0]) == 1)
|
||||
self.assertIsInstance(inputs_url["pixel_values"][0][0], torch.Tensor)
|
||||
|
||||
# fmt: off
|
||||
input_ids = inputs_url["input_ids"]
|
||||
self.assertEqual(
|
||||
input_ids[0].tolist(),
|
||||
# Equivalent to "USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the content of the image? ASSISTANT:"
|
||||
[21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 4701, 1307, 1278, 3937, 1063, 1349, 4290, 16002, 41150, 1058]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
def test_processor_with_multiple_images_single_list(self):
|
||||
prompt_string = "USER: [IMG][IMG]\nWhat's the difference between these two images? ASSISTANT:"
|
||||
|
||||
# Make small for checking image token expansion
|
||||
self.processor.image_processor.size = {"longest_edge": 30}
|
||||
self.processor.image_processor.patch_size = {"height": 2, "width": 2}
|
||||
|
||||
# Test passing in an image
|
||||
inputs_image = self.processor(text=prompt_string, images=[self.image_0, self.image_1], return_tensors="pt")
|
||||
self.assertIn("input_ids", inputs_image)
|
||||
self.assertTrue(len(inputs_image["input_ids"]) == 1)
|
||||
self.assertIsInstance(inputs_image["input_ids"], torch.Tensor)
|
||||
self.assertIsInstance(inputs_image["pixel_values"], list)
|
||||
self.assertTrue(len(inputs_image["pixel_values"]) == 1)
|
||||
self.assertIsInstance(inputs_image["pixel_values"][0], list)
|
||||
self.assertTrue(len(inputs_image["pixel_values"][0]) == 2)
|
||||
self.assertIsInstance(inputs_image["pixel_values"][0][0], torch.Tensor)
|
||||
|
||||
# fmt: off
|
||||
input_ids = inputs_image["input_ids"]
|
||||
self.assertEqual(
|
||||
input_ids[0].tolist(),
|
||||
# Equivalent to ["USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END][IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the difference between these two images? ASSISTANT:"]
|
||||
[21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
# Test passing in a url
|
||||
inputs_url = self.processor(text=prompt_string, images=[self.url_0, self.url_1], return_tensors="pt")
|
||||
self.assertIn("input_ids", inputs_url)
|
||||
self.assertTrue(len(inputs_url["input_ids"]) == 1)
|
||||
self.assertIsInstance(inputs_url["input_ids"], torch.Tensor)
|
||||
self.assertIsInstance(inputs_url["pixel_values"], list)
|
||||
self.assertTrue(len(inputs_url["pixel_values"]) == 1)
|
||||
self.assertIsInstance(inputs_url["pixel_values"][0], list)
|
||||
self.assertTrue(len(inputs_url["pixel_values"][0]) == 2)
|
||||
self.assertIsInstance(inputs_url["pixel_values"][0][0], torch.Tensor)
|
||||
# fmt: off
|
||||
input_ids = inputs_url["input_ids"]
|
||||
self.assertEqual(
|
||||
input_ids[0].tolist(),
|
||||
# Equivalent to ["USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END][IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the difference between these two images? ASSISTANT:"]
|
||||
[21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
def test_processor_with_multiple_images_multiple_lists(self):
|
||||
prompt_string = [
|
||||
"USER: [IMG][IMG]\nWhat's the difference between these two images? ASSISTANT:",
|
||||
"USER: [IMG]\nWhat's the content of the image? ASSISTANT:",
|
||||
]
|
||||
self.processor.tokenizer.pad_token = "</s>"
|
||||
image_inputs = [[self.image_0, self.image_1], [self.image_2]]
|
||||
|
||||
# Make small for checking image token expansion
|
||||
self.processor.image_processor.size = {"longest_edge": 30}
|
||||
self.processor.image_processor.patch_size = {"height": 2, "width": 2}
|
||||
|
||||
# Test passing in an image
|
||||
inputs_image = self.processor(text=prompt_string, images=image_inputs, return_tensors="pt", padding=True)
|
||||
self.assertIn("input_ids", inputs_image)
|
||||
self.assertTrue(len(inputs_image["input_ids"]) == 2)
|
||||
self.assertIsInstance(inputs_image["input_ids"], torch.Tensor)
|
||||
self.assertIsInstance(inputs_image["pixel_values"], list)
|
||||
self.assertTrue(len(inputs_image["pixel_values"]) == 2)
|
||||
self.assertIsInstance(inputs_image["pixel_values"][0], list)
|
||||
self.assertTrue(len(inputs_image["pixel_values"][0]) == 2)
|
||||
self.assertIsInstance(inputs_image["pixel_values"][0][0], torch.Tensor)
|
||||
|
||||
# fmt: off
|
||||
input_ids = inputs_image["input_ids"]
|
||||
self.assertEqual(
|
||||
input_ids[0].tolist(),
|
||||
# Equivalent to ["USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END][IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the difference between these two images? ASSISTANT:"]
|
||||
[21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
# Test passing in a url
|
||||
inputs_url = self.processor(text=prompt_string, images=image_inputs, return_tensors="pt", padding=True)
|
||||
self.assertIn("input_ids", inputs_url)
|
||||
self.assertTrue(len(inputs_url["input_ids"]) == 2)
|
||||
self.assertIsInstance(inputs_url["input_ids"], torch.Tensor)
|
||||
self.assertIsInstance(inputs_url["pixel_values"], list)
|
||||
self.assertTrue(len(inputs_url["pixel_values"]) == 2)
|
||||
self.assertIsInstance(inputs_url["pixel_values"][0], list)
|
||||
self.assertTrue(len(inputs_url["pixel_values"][0]) == 2)
|
||||
self.assertIsInstance(inputs_url["pixel_values"][0][0], torch.Tensor)
|
||||
|
||||
# fmt: off
|
||||
input_ids = inputs_url["input_ids"]
|
||||
self.assertEqual(
|
||||
input_ids[0].tolist(),
|
||||
# Equivalent to ["USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END][IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the difference between these two images? ASSISTANT:"]
|
||||
[21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058]
|
||||
)
|
||||
# fmt: on
|
||||
Reference in New Issue
Block a user