Add Ovis2 model and processor implementation (#37088)
* Add Ovis2 model and processor implementation * Apply style fixes * Add unit tests for Ovis2 image processing and processor * Refactor image processing functions for clarity and efficiency * Add Ovis2 ImageProcessorFast * Refactor Ovis2 code * Refactor Ovis2 model components and update processor functionality * Fix repo consistency issues for Ovis2: docstring, config cleanup * Update Ovis2 model integration tests * Update Ovis2 configuration and processing classes for improved documentation * Remove duplicate entry for 'ovis2' in VLM_CLASS_NAMES * Fix conflict * Fix import order * Update image processor class names * Update Ovis2 model structure * Refactor Ovis2 configuration * Fix typos * Refactor Ovis2 model classes and remove unused code * Fix typos * Refactor Ovis2 model initialization * Fiix typos * Remove Ovis2 model mapping from MODEL_MAPPING_NAMES in modeling_auto.py * Add license and update type hints * Refactor token function and update docstring handling * Add license * Add Ovis2 model support and update documentation * Refactor Ovis2 model structure and enhance multimodal capabilities * Update Ovis2 weight mapping for consistency and clarity in key patterns * Remove unused 'grids' parameter from Ovis2 model and Update processing logic to handle image grids more efficiently. * Refactor Ovis2 model test structure to include Ovis2Model * Add optional disable_grouping param to Ovis2ImageProcessorFast * Refactor type hints in Ovis2 modules * Add licensing information in Ovis2 modules and tests * Refactor Ovis2 model by removing unused methods * Refactor Ovis2 model tests by renaming test classes and removing skipped tests * Refactor Ovis2 model output classes * Refactor Ovis2 weight conversion and Update model embedding classes * Refactor Ovis2 model imports and remove unused functions * Enhance vision configuration extraction in Ovis2 weight conversion * Refactor Ovis2 model's forward method to remove interpolation option * Update Ovis2 model documentation * Refactor Ovis2 model input handling and tokenizer configuration * Update return type hints in Ovis2 model * Remove commented-out code * fix config for tests and remove key mappings * Update tokenizer configuration to use add_special_tokens method * skip torchscript * Fix image placeholder generation in Ovis2Processor * Refactor Ovis2 model to rename visual_table to visual_embeddings_table * Enhance Ovis2 model by adding vision_feature_select_strategy parameter * Refactor Ovis2 model weights conversion and architecture * Refactor Ovis2 model by removing vision_feature_select_strategy parameter * Update Ovis2 model examples * Refactor Ovis2 model * Update Ovis2 model * Update Ovis2 model configuration * Refactor Ovis2 model test setup * Refactor flash attention support * Refactor * Fix typo * Refactor * Refactor model classes * Update expected output in Ovis2 * Refactor docstrings * Fix * Fix * Fix * Update input in tests * Fix * Fix get_decoder method * Refactor * Refactor Ovis2 * Fix * Fix * Fix test * Add get_placeholder_mask * Refactor Ovis2 model tests * Fix * Refactor * Fix * Fix * Fix Ovis2 test --------- Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>
This commit is contained in:
@@ -1077,6 +1077,8 @@
|
||||
title: OmDet-Turbo
|
||||
- local: model_doc/oneformer
|
||||
title: OneFormer
|
||||
- local: model_doc/ovis2
|
||||
title: Ovis2
|
||||
- local: model_doc/owlvit
|
||||
title: OWL-ViT
|
||||
- local: model_doc/owlv2
|
||||
|
||||
105
docs/source/en/model_doc/ovis2.md
Normal file
105
docs/source/en/model_doc/ovis2.md
Normal file
@@ -0,0 +1,105 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Ovis2
|
||||
|
||||
## Overview
|
||||
|
||||
The [Ovis2](https://github.com/AIDC-AI/Ovis) is an updated version of the [Ovis](https://arxiv.org/abs/2405.20797) model developed by the AIDC-AI team at Alibaba International Digital Commerce Group.
|
||||
|
||||
Ovis2 is the latest advancement in multi-modal large language models (MLLMs), succeeding Ovis1.6. It retains the architectural design of the Ovis series, which focuses on aligning visual and textual embeddings, and introduces major improvements in data curation and training methods.
|
||||
|
||||
<img src="https://cdn-uploads.huggingface.co/production/uploads/637aebed7ce76c3b834cea37/XB-vgzDL6FshrSNGyZvzc.png" width="600">
|
||||
|
||||
<small> Ovis2 architecture.</small>
|
||||
|
||||
This model was contributed by [thisisiron](https://huggingface.co/thisisiron).
|
||||
|
||||
## Usage example
|
||||
|
||||
```python
|
||||
|
||||
from PIL import Image
|
||||
import requests
|
||||
import torch
|
||||
from torchvision import io
|
||||
from typing import Dict
|
||||
from transformers.image_utils import load_images, load_video
|
||||
from transformers import AutoModelForVision2Seq, AutoTokenizer, AutoProcessor
|
||||
|
||||
model = AutoModelForVision2Seq.from_pretrained(
|
||||
"thisisiron/Ovis2-2B-hf",
|
||||
torch_dtype=torch.bfloat16,
|
||||
).eval().to("cuda:0")
|
||||
processor = AutoProcessor.from_pretrained("thisisiron/Ovis2-2B-hf")
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": "Describe the image."},
|
||||
],
|
||||
},
|
||||
]
|
||||
url = "http://images.cocodataset.org/val2014/COCO_val2014_000000537955.jpg"
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
messages = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
print(messages)
|
||||
|
||||
inputs = processor(
|
||||
images=[image],
|
||||
text=messages,
|
||||
return_tensors="pt",
|
||||
)
|
||||
inputs = inputs.to("cuda:0")
|
||||
inputs['pixel_values'] = inputs['pixel_values'].to(torch.bfloat16)
|
||||
|
||||
with torch.inference_mode():
|
||||
output_ids = model.generate(**inputs, max_new_tokens=128, do_sample=False)
|
||||
generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, output_ids)]
|
||||
output_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
print(output_text)
|
||||
```
|
||||
|
||||
## Ovis2Config
|
||||
|
||||
[[autodoc]] Ovis2Config
|
||||
|
||||
## Ovis2VisionConfig
|
||||
|
||||
[[autodoc]] Ovis2VisionConfig
|
||||
|
||||
## Ovis2Model
|
||||
|
||||
[[autodoc]] Ovis2Model
|
||||
|
||||
## Ovis2ForConditionalGeneration
|
||||
|
||||
[[autodoc]] Ovis2ForConditionalGeneration
|
||||
- forward
|
||||
|
||||
## Ovis2ImageProcessor
|
||||
|
||||
[[autodoc]] Ovis2ImageProcessor
|
||||
|
||||
## Ovis2ImageProcessorFast
|
||||
|
||||
[[autodoc]] Ovis2ImageProcessorFast
|
||||
|
||||
## Ovis2Processor
|
||||
|
||||
[[autodoc]] Ovis2Processor
|
||||
@@ -240,6 +240,7 @@ if TYPE_CHECKING:
|
||||
from .oneformer import *
|
||||
from .openai import *
|
||||
from .opt import *
|
||||
from .ovis2 import *
|
||||
from .owlv2 import *
|
||||
from .owlvit import *
|
||||
from .paligemma import *
|
||||
|
||||
@@ -280,6 +280,7 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
|
||||
("open-llama", "OpenLlamaConfig"),
|
||||
("openai-gpt", "OpenAIGPTConfig"),
|
||||
("opt", "OPTConfig"),
|
||||
("ovis2", "Ovis2Config"),
|
||||
("owlv2", "Owlv2Config"),
|
||||
("owlvit", "OwlViTConfig"),
|
||||
("paligemma", "PaliGemmaConfig"),
|
||||
@@ -707,6 +708,7 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str](
|
||||
("open-llama", "OpenLlama"),
|
||||
("openai-gpt", "OpenAI GPT"),
|
||||
("opt", "OPT"),
|
||||
("ovis2", "Ovis2"),
|
||||
("owlv2", "OWLv2"),
|
||||
("owlvit", "OWL-ViT"),
|
||||
("paligemma", "PaliGemma"),
|
||||
|
||||
@@ -139,6 +139,7 @@ else:
|
||||
("nat", ("ViTImageProcessor", "ViTImageProcessorFast")),
|
||||
("nougat", ("NougatImageProcessor", "NougatImageProcessorFast")),
|
||||
("oneformer", ("OneFormerImageProcessor", "OneFormerImageProcessorFast")),
|
||||
("ovis2", ("Ovis2ImageProcessor", "Ovis2ImageProcessorFast")),
|
||||
("owlv2", ("Owlv2ImageProcessor", "Owlv2ImageProcessorFast")),
|
||||
("owlvit", ("OwlViTImageProcessor", "OwlViTImageProcessorFast")),
|
||||
("paligemma", ("SiglipImageProcessor", "SiglipImageProcessorFast")),
|
||||
|
||||
@@ -279,6 +279,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("open-llama", "OpenLlamaModel"),
|
||||
("openai-gpt", "OpenAIGPTModel"),
|
||||
("opt", "OPTModel"),
|
||||
("ovis2", "Ovis2Model"),
|
||||
("owlv2", "Owlv2Model"),
|
||||
("owlvit", "OwlViTModel"),
|
||||
("paligemma", "PaliGemmaModel"),
|
||||
@@ -948,6 +949,7 @@ MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
|
||||
("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
|
||||
("mistral3", "Mistral3ForConditionalGeneration"),
|
||||
("mllama", "MllamaForConditionalGeneration"),
|
||||
("ovis2", "Ovis2ForConditionalGeneration"),
|
||||
("paligemma", "PaliGemmaForConditionalGeneration"),
|
||||
("pix2struct", "Pix2StructForConditionalGeneration"),
|
||||
("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),
|
||||
@@ -997,6 +999,7 @@ MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict(
|
||||
("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
|
||||
("mistral3", "Mistral3ForConditionalGeneration"),
|
||||
("mllama", "MllamaForConditionalGeneration"),
|
||||
("ovis2", "Ovis2ForConditionalGeneration"),
|
||||
("paligemma", "PaliGemmaForConditionalGeneration"),
|
||||
("perception_lm", "PerceptionLMForConditionalGeneration"),
|
||||
("pix2struct", "Pix2StructForConditionalGeneration"),
|
||||
|
||||
@@ -104,6 +104,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
|
||||
("mm-grounding-dino", "GroundingDinoProcessor"),
|
||||
("moonshine", "Wav2Vec2Processor"),
|
||||
("oneformer", "OneFormerProcessor"),
|
||||
("ovis2", "Ovis2Processor"),
|
||||
("owlv2", "Owlv2Processor"),
|
||||
("owlvit", "OwlViTProcessor"),
|
||||
("paligemma", "PaliGemmaProcessor"),
|
||||
|
||||
32
src/transformers/models/ovis2/__init__.py
Normal file
32
src/transformers/models/ovis2/__init__.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import _LazyModule
|
||||
from ...utils.import_utils import define_import_structure
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_ovis2 import *
|
||||
from .image_processing_ovis2 import *
|
||||
from .image_processing_ovis2_fast import *
|
||||
from .modeling_ovis2 import *
|
||||
from .processing_ovis2 import *
|
||||
else:
|
||||
import sys
|
||||
|
||||
_file = globals()["__file__"]
|
||||
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
||||
179
src/transformers/models/ovis2/configuration_ovis2.py
Normal file
179
src/transformers/models/ovis2/configuration_ovis2.py
Normal file
@@ -0,0 +1,179 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ..qwen2.configuration_qwen2 import Qwen2Config
|
||||
|
||||
|
||||
class Ovis2VisionConfig(PretrainedConfig):
|
||||
r"""This is the configuration class to store the configuration of a [`Ovis2VisionModel`]. It is used to instantiate a
|
||||
Ovis2VisionModel model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||
with the defaults will yield a similar configuration to that of Ovis2.
|
||||
|
||||
Args:
|
||||
hidden_size (`int`, *optional*, defaults to 1024):
|
||||
Dimensionality of the encoder layers and the pooler layer.
|
||||
intermediate_size (`int`, *optional*, defaults to 2816):
|
||||
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 24):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 8):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
num_channels (`int`, *optional*, defaults to 3):
|
||||
Number of channels in the input images.
|
||||
image_size (`int`, *optional*, defaults to 224):
|
||||
The size (resolution) of each image.
|
||||
patch_size (`int`, *optional*, defaults to 14):
|
||||
The size (resolution) of each patch.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
|
||||
The epsilon used by the RMSNorm layers.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
qkv_bias (`bool`, *optional*, defaults to `False`):
|
||||
Whether to add a learnable bias to the query, key, and value sequences at each attention head.
|
||||
mlp_bias (`bool`, *optional*, defaults to `False`):
|
||||
Whether to add a learnable bias to the MLP layers.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
||||
`"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
|
||||
vocab_size (`int`, *optional*, defaults to 16384):
|
||||
Vocabulary size of the Vision Transformer.
|
||||
hidden_stride (`int`, *optional*, defaults to 1):
|
||||
The stride of the hidden layer in the Vision Transformer.
|
||||
num_visual_indicator_tokens (`int`, *optional*, defaults to 5):
|
||||
Number of visual indicator tokens.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated normal initializer for initializing all weight matrices.
|
||||
tokenize_function (`str`, *optional*, defaults to `"softmax"`):
|
||||
The function used to tokenize the visual indicator tokens.
|
||||
```"""
|
||||
|
||||
base_config_key = "vision_config"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 1024,
|
||||
intermediate_size: int = 2816,
|
||||
num_hidden_layers: int = 24,
|
||||
num_attention_heads: int = 8,
|
||||
num_channels: int = 3,
|
||||
image_size: int = 224,
|
||||
patch_size: int = 14,
|
||||
rms_norm_eps: float = 1e-5,
|
||||
attention_dropout: float = 0.0,
|
||||
qkv_bias: bool = False,
|
||||
mlp_bias: bool = False,
|
||||
hidden_act="silu",
|
||||
vocab_size=16384,
|
||||
hidden_stride=1,
|
||||
num_visual_indicator_tokens=5,
|
||||
initializer_range=0.02,
|
||||
tokenize_function="softmax",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_channels = num_channels
|
||||
self.patch_size = patch_size
|
||||
self.image_size = image_size
|
||||
|
||||
self.attention_dropout = attention_dropout
|
||||
self.hidden_act = hidden_act
|
||||
self.qkv_bias = qkv_bias
|
||||
self.mlp_bias = mlp_bias
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_stride = hidden_stride
|
||||
self.num_visual_indicator_tokens = num_visual_indicator_tokens
|
||||
self.tokenize_function = tokenize_function
|
||||
self.initializer_range = initializer_range
|
||||
|
||||
|
||||
class Ovis2Config(PretrainedConfig):
|
||||
r"""This is the configuration class to store the configuration of a [`Ovis2ForConditionalGeneration`]. It is used to instantiate a
|
||||
Ovis2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||
with the defaults will yield a similar configuration to that of Ovis2.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
e.g. [thisisiron/Ovis2-1B-hf](https://huggingface.co/thisisiron/Ovis2-1B-hf)
|
||||
|
||||
Args:
|
||||
vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `Ovis2VisionConfig`):
|
||||
The config object or dictionary of the vision backbone.
|
||||
text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `Qwen2Config`):
|
||||
The config object or dictionary of the text backbone.
|
||||
image_token_id (`int`, *optional*, defaults to 151665):
|
||||
The image token id to encode the image prompt.
|
||||
visual_indicator_token_ids (`List[int]`, *optional*, defaults to `[151666, 151667, 151668, 151669, 151670]`):
|
||||
The visual indicator token ids to encode the image prompt.
|
||||
vocab_size (`int`, *optional*, defaults to 151643):
|
||||
Vocabulary size of the text model.
|
||||
hidden_size (`int`, *optional*, defaults to 1536):
|
||||
Dimensionality of the encoder layers and the pooler layer.
|
||||
|
||||
```python
|
||||
>>> from transformers import Ovis2ForConditionalGeneration, Ovis2Config
|
||||
|
||||
>>> # Initializing a Ovis2 style configuration
|
||||
>>> configuration = Ovis2Config()
|
||||
|
||||
>>> # Initializing a model from the Ovis2-2B style configuration
|
||||
>>> model = Ovis2ForConditionalGeneration(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "ovis2"
|
||||
sub_configs = {"text_config": Qwen2Config, "vision_config": Ovis2VisionConfig}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vision_config=None,
|
||||
text_config=None,
|
||||
image_token_id=151665,
|
||||
visual_indicator_token_ids=[151666, 151667, 151668, 151669, 151670],
|
||||
vocab_size=151643,
|
||||
hidden_size=1536,
|
||||
**kwargs,
|
||||
):
|
||||
if isinstance(vision_config, dict):
|
||||
self.vision_config = Ovis2VisionConfig(**vision_config)
|
||||
elif isinstance(vision_config, Ovis2VisionConfig):
|
||||
self.vision_config = vision_config
|
||||
if vision_config is None:
|
||||
self.vision_config = Ovis2VisionConfig(num_visual_indicator_tokens=len(visual_indicator_token_ids))
|
||||
|
||||
if isinstance(text_config, dict):
|
||||
self.text_config = Qwen2Config(**text_config)
|
||||
elif isinstance(text_config, Qwen2Config):
|
||||
self.text_config = text_config
|
||||
elif text_config is None:
|
||||
self.text_config = Qwen2Config()
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.image_token_id = image_token_id
|
||||
self.visual_indicator_token_ids = visual_indicator_token_ids
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
__all__ = ["Ovis2VisionConfig", "Ovis2Config"]
|
||||
404
src/transformers/models/ovis2/convert_ovis2_weights_to_hf.py
Normal file
404
src/transformers/models/ovis2/convert_ovis2_weights_to_hf.py
Normal file
@@ -0,0 +1,404 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForImageTextToText,
|
||||
AutoProcessor,
|
||||
AutoTokenizer,
|
||||
)
|
||||
from transformers.models.auto.configuration_auto import CONFIG_MAPPING_NAMES
|
||||
from transformers.models.ovis2.configuration_ovis2 import Ovis2Config, Ovis2VisionConfig
|
||||
from transformers.models.ovis2.image_processing_ovis2 import Ovis2ImageProcessor
|
||||
from transformers.models.ovis2.modeling_ovis2 import Ovis2ForConditionalGeneration
|
||||
from transformers.models.ovis2.processing_ovis2 import Ovis2Processor
|
||||
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
|
||||
|
||||
|
||||
# Constants
|
||||
CONTEXT_LENGTH = 32768 # multimodal_max_length
|
||||
|
||||
|
||||
# fmt: off
|
||||
|
||||
# Mapping from original model key patterns to HF key patterns
|
||||
ORIGINAL_TO_HF_MAPPING = {
|
||||
r"trunk.blocks\.(\d+)\.norm_1": r"encoder.layers.\1.rms_norm1",
|
||||
r"trunk.blocks\.(\d+)\.norm_2": r"encoder.layers.\1.rms_norm2",
|
||||
r"trunk.blocks\.(\d+)\.attn.proj": r"encoder.layers.\1.attention.out_proj",
|
||||
r"visual_tokenizer": r"model.vision_tower",
|
||||
r"backbone": r"transformer",
|
||||
r"preprocessor": r"embeddings",
|
||||
r"patchifier.proj": r"patch_embedding",
|
||||
r"patchifier.norm": r"rms_norm",
|
||||
r"trunk.post_trunk_norm": r"rms_norm",
|
||||
r"trunk.blocks": r"encoder.layers",
|
||||
r"mlp.fc1": r"ffn.gate_proj",
|
||||
r"mlp.fc2": r"ffn.down_proj",
|
||||
r"mlp.fc3": r"ffn.up_proj",
|
||||
r"head.0": r"head_linear",
|
||||
r"head.1": r"head_norm",
|
||||
r"vte.weight": r"model.visual_embeddings_table.weight",
|
||||
r"llm.model": r"model.language_model",
|
||||
r"llm.lm_head": r"lm_head",
|
||||
}
|
||||
# fmt: on
|
||||
|
||||
# Special tokens for the tokenizer
|
||||
SPECIAL_TOKENS = [
|
||||
"<IMG_ATOM>",
|
||||
"<IMG_START>",
|
||||
"<IMG_GRID>",
|
||||
"<IMG_COL>",
|
||||
"<IMG_ROW>",
|
||||
"<IMG_END>",
|
||||
]
|
||||
|
||||
# Configuration keys to ignore when converting
|
||||
UNNECESSARY_CONFIG_KEYS = [
|
||||
"_name_or_path",
|
||||
"_attn_implementation_autoset",
|
||||
"auto_map",
|
||||
"use_bfloat16",
|
||||
"use_flash_attn",
|
||||
"qk_normalization",
|
||||
"bias",
|
||||
"norm_type",
|
||||
]
|
||||
|
||||
# Chat template for the tokenizer
|
||||
CHAT_TEMPLATE = (
|
||||
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
|
||||
"{% for message in messages %}"
|
||||
"{{'<|im_start|>' + message['role'] + '\n'}}"
|
||||
"{% if message['content'] is string %}"
|
||||
"{{ message['content'] }}"
|
||||
"{% else %}"
|
||||
"{% for content in message['content'] %}"
|
||||
"{% if content['type'] == 'image' %}"
|
||||
"{{ '<image>\n' }}"
|
||||
"{% elif content['type'] == 'text' %}"
|
||||
"{{ content['text'] }}"
|
||||
"{% endif %}"
|
||||
"{% endfor %}"
|
||||
"{% endif %}"
|
||||
"{{'<|im_end|>\n'}}"
|
||||
"{% endfor %}"
|
||||
"{% if add_generation_prompt %}"
|
||||
"{{'<|im_start|>assistant\n' }}"
|
||||
"{% endif %}"
|
||||
)
|
||||
|
||||
|
||||
def create_tokenizer(model_name_or_path, save_dir):
|
||||
"""
|
||||
Create and configure a tokenizer for the Ovis2 model.
|
||||
|
||||
Args:
|
||||
model_name_or_path: Path to the source model or tokenizer
|
||||
save_dir: Directory to save the tokenizer to
|
||||
|
||||
Returns:
|
||||
The configured tokenizer
|
||||
"""
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, return_token_type_ids=False)
|
||||
tokenizer.model_max_length = CONTEXT_LENGTH
|
||||
tokenizer.add_special_tokens({"additional_special_tokens": SPECIAL_TOKENS})
|
||||
tokenizer.chat_template = CHAT_TEMPLATE
|
||||
setattr(tokenizer, "image_token", "<IMG_ATOM>") # 151665
|
||||
setattr(tokenizer, "image_token_id", tokenizer.convert_tokens_to_ids(tokenizer.image_token))
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
def create_image_processor(save_dir):
|
||||
"""
|
||||
Create and save an image processor for the Ovis2 model.
|
||||
|
||||
Args:
|
||||
save_dir: Directory to save the image processor to
|
||||
|
||||
Returns:
|
||||
The configured image processor
|
||||
"""
|
||||
image_processor = Ovis2ImageProcessor(
|
||||
crop_to_patches=True,
|
||||
size={"height": 448, "width": 448},
|
||||
)
|
||||
return image_processor
|
||||
|
||||
|
||||
def extract_vision_config_from_original(orig_config):
|
||||
"""
|
||||
Extract and format vision configuration from the original model config.
|
||||
|
||||
Args:
|
||||
orig_config: Original model configuration
|
||||
|
||||
Returns:
|
||||
dict: Cleaned vision configuration dictionary
|
||||
"""
|
||||
visual_tokenizer_config = orig_config.visual_tokenizer_config.to_dict()
|
||||
# backbone_config = visual_tokenizer_config.pop("backbone_config")
|
||||
|
||||
# Copy required fields from backbone config
|
||||
visual_tokenizer_config["hidden_size"] = orig_config.visual_tokenizer_config.backbone_config.hidden_size
|
||||
visual_tokenizer_config["intermediate_size"] = (
|
||||
orig_config.visual_tokenizer_config.backbone_config.intermediate_size
|
||||
)
|
||||
visual_tokenizer_config["num_attention_heads"] = (
|
||||
orig_config.visual_tokenizer_config.backbone_config.num_attention_heads
|
||||
)
|
||||
visual_tokenizer_config["num_hidden_layers"] = (
|
||||
orig_config.visual_tokenizer_config.backbone_config.num_hidden_layers
|
||||
)
|
||||
visual_tokenizer_config["rms_norm_eps"] = orig_config.visual_tokenizer_config.backbone_config.rms_norm_eps
|
||||
visual_tokenizer_config["image_size"] = orig_config.visual_tokenizer_config.backbone_config.image_size
|
||||
visual_tokenizer_config["num_channels"] = orig_config.visual_tokenizer_config.backbone_config.num_channels
|
||||
visual_tokenizer_config["patch_size"] = orig_config.visual_tokenizer_config.backbone_config.patch_size
|
||||
visual_tokenizer_config["qkv_bias"] = orig_config.visual_tokenizer_config.backbone_config.qkv_bias
|
||||
|
||||
# Remove unnecessary keys
|
||||
return {k: v for k, v in visual_tokenizer_config.items() if k not in UNNECESSARY_CONFIG_KEYS}
|
||||
|
||||
|
||||
def get_ovis2_config(model_name_or_path):
|
||||
"""
|
||||
Create an Ovis2 configuration from the original model.
|
||||
|
||||
Args:
|
||||
model_name_or_path: Path to the original model
|
||||
|
||||
Returns:
|
||||
Ovis2Config: Configuration for the HF implementation
|
||||
"""
|
||||
orig_config = AutoModelForCausalLM.from_pretrained(
|
||||
model_name_or_path,
|
||||
trust_remote_code=True,
|
||||
).config
|
||||
|
||||
# Extract and clean LLM config
|
||||
llm_config = orig_config.llm_config.to_dict()
|
||||
llm_config = {k: v for k, v in llm_config.items() if k not in UNNECESSARY_CONFIG_KEYS}
|
||||
|
||||
# Extract and clean vision config
|
||||
visual_tokenizer_config = extract_vision_config_from_original(orig_config)
|
||||
|
||||
return Ovis2Config(
|
||||
text_config=Qwen2Config(**llm_config),
|
||||
vision_config=Ovis2VisionConfig(**visual_tokenizer_config),
|
||||
hidden_size=llm_config["hidden_size"],
|
||||
vocab_size=llm_config["vocab_size"],
|
||||
initializer_range=llm_config["initializer_range"],
|
||||
)
|
||||
|
||||
|
||||
def load_orig_state_dict(model_name_or_path):
|
||||
"""
|
||||
Load the state dictionary from the original model.
|
||||
|
||||
Args:
|
||||
model_name_or_path: Path to the original model
|
||||
|
||||
Returns:
|
||||
dict: Original model state dictionary
|
||||
"""
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name_or_path,
|
||||
torch_dtype=torch.bfloat16,
|
||||
trust_remote_code=True,
|
||||
).eval()
|
||||
|
||||
return model.state_dict()
|
||||
|
||||
|
||||
def convert_orig2hf(state_dict, dim):
|
||||
"""
|
||||
Convert original state dictionary keys to HF format.
|
||||
|
||||
Args:
|
||||
state_dict: Original state dictionary
|
||||
dim: Hidden dimension for splitting QKV weights
|
||||
|
||||
Returns:
|
||||
dict: Converted state dictionary for HF model
|
||||
"""
|
||||
new_state_dict = {}
|
||||
|
||||
for key, val in state_dict.items():
|
||||
orig_key = key
|
||||
|
||||
# Apply regex pattern replacements
|
||||
for pattern, replacement in ORIGINAL_TO_HF_MAPPING.items():
|
||||
key = re.sub(pattern, replacement, key)
|
||||
|
||||
# Handle special cases
|
||||
if "attn.qkv" in key:
|
||||
# Split QKV into separate Q, K, V matrices
|
||||
new_key_query = key.replace("attn.qkv", "attention.q_proj")
|
||||
new_state_dict[new_key_query] = state_dict[orig_key][:dim]
|
||||
|
||||
new_key_key = key.replace("attn.qkv", "attention.k_proj")
|
||||
new_state_dict[new_key_key] = state_dict[orig_key][dim : 2 * dim]
|
||||
|
||||
new_key_value = key.replace("attn.qkv", "attention.v_proj")
|
||||
new_state_dict[new_key_value] = state_dict[orig_key][-dim:]
|
||||
|
||||
elif "pos_embed" in key:
|
||||
new_key = key.replace("pos_embed", "position_embedding.weight")
|
||||
new_state_dict[new_key] = state_dict[orig_key][0]
|
||||
|
||||
else:
|
||||
new_state_dict[key] = val
|
||||
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def convert_model(model_name_or_path):
|
||||
"""
|
||||
Convert and save the model in HF format.
|
||||
|
||||
Args:
|
||||
model_name_or_path: Path to the original model
|
||||
save_dir: Directory to save the converted model
|
||||
|
||||
Returns:
|
||||
The converted model
|
||||
"""
|
||||
|
||||
config = get_ovis2_config(model_name_or_path)
|
||||
config.architectures = ["Ovis2ForConditionalGeneration"]
|
||||
|
||||
# Load and convert weights
|
||||
orig_state_dict = load_orig_state_dict(model_name_or_path)
|
||||
new_state_dict = convert_orig2hf(orig_state_dict, config.vision_config.hidden_size)
|
||||
|
||||
# Create model and load converted weights
|
||||
model = Ovis2ForConditionalGeneration(config)
|
||||
missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False)
|
||||
|
||||
# Report any issues with weight loading
|
||||
if missing_keys:
|
||||
print(f"Missing keys: {missing_keys}")
|
||||
if unexpected_keys:
|
||||
print(f"Unexpected keys: {unexpected_keys}")
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def main():
|
||||
"""Process command line arguments and execute the conversion pipeline."""
|
||||
parser = argparse.ArgumentParser(description="Convert Ovis2 model to HF format")
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
default="AIDC-AI/Ovis2-2B",
|
||||
choices=[
|
||||
"AIDC-AI/Ovis2-1B",
|
||||
"AIDC-AI/Ovis2-2B",
|
||||
"AIDC-AI/Ovis2-4B",
|
||||
"AIDC-AI/Ovis2-8B",
|
||||
"AIDC-AI/Ovis2-16B",
|
||||
"AIDC-AI/Ovis2-34B",
|
||||
],
|
||||
help="Location of original Ovis2 model",
|
||||
)
|
||||
parser.add_argument("--save_dir", default="Ovis2-2B-hf", help="Location to write HF model and processors")
|
||||
parser.add_argument("--hub_dir", default="thisisiron/Ovis2-2B-hf", help="Hub repository name if pushing to hub")
|
||||
parser.add_argument(
|
||||
"--push_to_hub", action="store_true", help="Whether to push the converted model to the Hugging Face hub"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Execute conversion pipeline
|
||||
print(f"Converting model from {args.model_name_or_path} to {args.save_dir}")
|
||||
|
||||
# If already included in the transformers library, remove to avoid duplication.
|
||||
if "aimv2" in CONFIG_MAPPING_NAMES:
|
||||
CONFIG_MAPPING_NAMES.pop("aimv2")
|
||||
|
||||
tokenizer = create_tokenizer(
|
||||
model_name_or_path=args.model_name_or_path,
|
||||
save_dir=args.save_dir,
|
||||
)
|
||||
|
||||
image_processor = create_image_processor(
|
||||
save_dir=args.save_dir,
|
||||
)
|
||||
|
||||
os.makedirs(args.save_dir, exist_ok=True)
|
||||
|
||||
# Convert and save the model
|
||||
model = convert_model(model_name_or_path=args.model_name_or_path)
|
||||
model.save_pretrained(args.save_dir)
|
||||
|
||||
# Save the processor
|
||||
processor = Ovis2Processor(tokenizer=tokenizer, image_processor=image_processor, chat_template=CHAT_TEMPLATE)
|
||||
processor.save_pretrained(args.save_dir)
|
||||
|
||||
# Push to hub if requested
|
||||
if args.push_to_hub:
|
||||
processor.push_to_hub(args.hub_dir, use_temp_dir=True)
|
||||
model.push_to_hub(args.hub_dir, use_temp_dir=True)
|
||||
|
||||
model = (
|
||||
AutoModelForImageTextToText.from_pretrained(
|
||||
args.save_dir,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
.eval()
|
||||
.to("cuda:0")
|
||||
)
|
||||
processor = AutoProcessor.from_pretrained(args.save_dir)
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": "Describe the image."},
|
||||
],
|
||||
},
|
||||
]
|
||||
url = "http://images.cocodataset.org/val2014/COCO_val2014_000000537955.jpg"
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
messages = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
print(messages)
|
||||
|
||||
inputs = processor(
|
||||
images=[image],
|
||||
text=messages,
|
||||
return_tensors="pt",
|
||||
)
|
||||
inputs = inputs.to("cuda:0")
|
||||
inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
|
||||
|
||||
with torch.inference_mode():
|
||||
output_ids = model.generate(**inputs, max_new_tokens=128, do_sample=False)
|
||||
generated_ids = [output_ids[len(input_ids) :] for input_ids, output_ids in zip(inputs.input_ids, output_ids)]
|
||||
output_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
print(output_text)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
573
src/transformers/models/ovis2/image_processing_ovis2.py
Normal file
573
src/transformers/models/ovis2/image_processing_ovis2.py
Normal file
@@ -0,0 +1,573 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from functools import lru_cache
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
||||
from ...image_transforms import convert_to_rgb, resize, to_channel_dimension_format
|
||||
from ...image_utils import (
|
||||
OPENAI_CLIP_MEAN,
|
||||
OPENAI_CLIP_STD,
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
infer_channel_dimension_format,
|
||||
is_scaled_image,
|
||||
make_flat_list_of_images,
|
||||
to_numpy_array,
|
||||
valid_images,
|
||||
validate_preprocess_arguments,
|
||||
)
|
||||
from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
import PIL
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# Similar to image_processing_mllama.get_all_supported_aspect_ratios
|
||||
@lru_cache(maxsize=10)
|
||||
def get_all_supported_aspect_ratios(min_image_tiles: int, max_image_tiles: int) -> list[tuple[int, int]]:
|
||||
"""
|
||||
Computes all allowed aspect ratios for a given minimum and maximum number of input tiles.
|
||||
|
||||
This function calculates all possible arrangements of tiles that can be formed
|
||||
within the constraint of the minimum and maximum number of tiles. Each arrangement is
|
||||
represented by its aspect ratio (width/height) and the corresponding tile configuration.
|
||||
|
||||
Args:
|
||||
min_image_tiles (`int`):
|
||||
The minimum number of tiles allowed.
|
||||
max_image_tiles (`int`):
|
||||
The maximum number of tiles allowed.
|
||||
|
||||
Returns:
|
||||
`List[Tuple[int, int]]`: A list of tuples, each tuple representing a valid (width, height)
|
||||
configuration in terms of number of tiles.
|
||||
|
||||
Example:
|
||||
>>> get_all_supported_aspect_ratios(1, 4)
|
||||
[(1, 1), (1, 2), (2, 1), (1, 3), (3, 1), (1, 4), (2, 2), (4, 1)]
|
||||
|
||||
"""
|
||||
aspect_ratios = []
|
||||
for width in range(1, max_image_tiles + 1):
|
||||
for height in range(1, max_image_tiles + 1):
|
||||
if width * height <= max_image_tiles and width * height >= min_image_tiles:
|
||||
aspect_ratios.append((width, height))
|
||||
|
||||
aspect_ratios = sorted(aspect_ratios, key=lambda x: x[0] * x[1])
|
||||
|
||||
return aspect_ratios
|
||||
|
||||
|
||||
@lru_cache(maxsize=100)
|
||||
def get_optimal_tiled_canvas(
|
||||
original_image_size: tuple[int, int],
|
||||
target_tile_size: tuple[int, int],
|
||||
min_image_tiles: int,
|
||||
max_image_tiles: int,
|
||||
) -> tuple[int, int]:
|
||||
"""
|
||||
Given a minimum and maximum number of tiles, find the canvas with the closest aspect ratio to the
|
||||
original image aspect ratio.
|
||||
In case of tie-breaking condition when two canvases have the same aspect ratio difference, we favor the canvas with
|
||||
more tiles, until the area covered by the tiles is more than twice the target area, in order to avoid unnecessarily
|
||||
excessive tiling.
|
||||
"""
|
||||
possible_tile_arrangements = get_all_supported_aspect_ratios(min_image_tiles, max_image_tiles)
|
||||
|
||||
original_height, original_width = original_image_size
|
||||
target_tile_height, target_tile_width = target_tile_size
|
||||
aspect_ratio = original_width / original_height
|
||||
area = original_width * original_height
|
||||
|
||||
# find the grid with the best aspect ratio
|
||||
best_ratio_diff = float("inf")
|
||||
best_grid = (1, 1)
|
||||
for grid in possible_tile_arrangements:
|
||||
grid_aspect_ratio = grid[0] / grid[1]
|
||||
ratio_diff = abs(aspect_ratio - grid_aspect_ratio)
|
||||
if ratio_diff < best_ratio_diff:
|
||||
best_ratio_diff = ratio_diff
|
||||
best_grid = grid
|
||||
elif ratio_diff == best_ratio_diff:
|
||||
# if the aspect ratio difference is the same, we favor the grid with more patches
|
||||
# until the area covered by the patches is more than twice the original image area
|
||||
if area > 0.5 * target_tile_height * target_tile_width * grid[0] * grid[1]:
|
||||
best_grid = grid
|
||||
|
||||
return best_grid
|
||||
|
||||
|
||||
def compute_patch_covering_area(left: int, upper: int, right: int, lower: int, side: int) -> float:
|
||||
w = right - left
|
||||
h = lower - upper
|
||||
w, h = max(w, h), min(w, h)
|
||||
if w > side:
|
||||
h = h / w * side
|
||||
w = side
|
||||
return w * h
|
||||
|
||||
|
||||
def split_image_into_grid(h: int, w: int, grid: tuple[int, int]) -> list[tuple[int, int, int, int]]:
|
||||
row_height = h // grid[0]
|
||||
col_width = w // grid[1]
|
||||
return [
|
||||
(
|
||||
col * col_width,
|
||||
row * row_height,
|
||||
w if col == grid[1] - 1 else (col + 1) * col_width,
|
||||
h if row == grid[0] - 1 else (row + 1) * row_height,
|
||||
)
|
||||
for row in range(grid[0])
|
||||
for col in range(grid[1])
|
||||
]
|
||||
|
||||
|
||||
@lru_cache(maxsize=100)
|
||||
def get_min_tile_covering_grid(
|
||||
image_size: tuple[int, int],
|
||||
target_patch_size: int,
|
||||
max_image_tiles: int,
|
||||
covering_threshold: float = 0.9,
|
||||
) -> tuple[int, int]:
|
||||
image_height, image_width = image_size
|
||||
image_area = image_width * image_height
|
||||
|
||||
candidate_tile_grids = get_all_supported_aspect_ratios(1, max_image_tiles)
|
||||
evaluated_grids = []
|
||||
sufficient_covering_grids = []
|
||||
|
||||
for tile_grid in candidate_tile_grids:
|
||||
tile_regions = split_image_into_grid(image_height, image_width, tile_grid)
|
||||
tile_covering_ratio = (
|
||||
sum([compute_patch_covering_area(*region, target_patch_size) for region in tile_regions]) / image_area
|
||||
)
|
||||
|
||||
evaluated_grids.append((tile_grid, tile_covering_ratio))
|
||||
if tile_covering_ratio > covering_threshold:
|
||||
sufficient_covering_grids.append((tile_grid, tile_covering_ratio))
|
||||
|
||||
if sufficient_covering_grids:
|
||||
# Prefer fewer tiles and higher covering ratio
|
||||
return sorted(sufficient_covering_grids, key=lambda x: (x[0][0] * x[0][1], -x[1]))[0][0]
|
||||
else:
|
||||
# Fallback: prefer higher covering even if below threshold
|
||||
return sorted(evaluated_grids, key=lambda x: (-x[1], x[0][0] * x[0][1]))[0][0]
|
||||
|
||||
|
||||
class Ovis2ImageProcessor(BaseImageProcessor):
|
||||
r"""
|
||||
Constructs a Ovis2 image processor.
|
||||
|
||||
Args:
|
||||
do_resize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
|
||||
`do_resize` parameter in the `preprocess` method.
|
||||
size (`dict`, *optional*, defaults to `{"height": 384, "width": 384}`):
|
||||
Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
|
||||
method.
|
||||
crop_to_patches (`bool`, *optional*, defaults to `False`):
|
||||
Whether to crop the image to patches. Can be overridden by the `crop_to_patches` parameter in the
|
||||
`preprocess` method.
|
||||
min_patches (`int`, *optional*, defaults to 1):
|
||||
The minimum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is
|
||||
set to `True`. Can be overridden by the `min_patches` parameter in the `preprocess` method.
|
||||
max_patches (`int`, *optional*, defaults to 12):
|
||||
The maximum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is
|
||||
set to `True`. Can be overridden by the `max_patches` parameter in the `preprocess` method.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
|
||||
Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be
|
||||
overridden by the `resample` parameter in the `preprocess` method.
|
||||
do_rescale (`bool`, *optional*, defaults to `True`):
|
||||
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
|
||||
`do_rescale` parameter in the `preprocess` method.
|
||||
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
||||
Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be
|
||||
overridden by the `rescale_factor` parameter in the `preprocess` method.
|
||||
do_normalize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
|
||||
method. Can be overridden by the `do_normalize` parameter in the `preprocess` method.
|
||||
image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
|
||||
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
|
||||
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. Can be
|
||||
overridden by the `image_mean` parameter in the `preprocess` method.
|
||||
image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
|
||||
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
|
||||
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
|
||||
Can be overridden by the `image_std` parameter in the `preprocess` method.
|
||||
do_convert_rgb (`bool`, *optional*, defaults to `True`):
|
||||
Whether to convert the image to RGB.
|
||||
use_covering_area_grid (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use the covering area grid to determine the number of patches. Only has an effect if
|
||||
`crop_to_patches` is set to `True`. Can be overridden by the `use_covering_area_grid` parameter in the
|
||||
`preprocess` method.
|
||||
"""
|
||||
|
||||
model_input_names = ["pixel_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
do_resize: bool = True,
|
||||
size: Optional[dict[str, int]] = None,
|
||||
crop_to_patches: bool = False,
|
||||
min_patches: int = 1,
|
||||
max_patches: int = 12,
|
||||
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
||||
do_rescale: bool = True,
|
||||
rescale_factor: Union[int, float] = 1 / 255,
|
||||
do_normalize: bool = True,
|
||||
image_mean: Optional[Union[float, list[float]]] = None,
|
||||
image_std: Optional[Union[float, list[float]]] = None,
|
||||
do_convert_rgb: bool = True,
|
||||
use_covering_area_grid: bool = True,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
size = size if size is not None else {"height": 384, "width": 384}
|
||||
size = get_size_dict(size, default_to_square=True)
|
||||
|
||||
self.do_resize = do_resize
|
||||
self.size = size
|
||||
self.crop_to_patches = crop_to_patches
|
||||
self.min_patches = min_patches
|
||||
self.max_patches = max_patches
|
||||
self.resample = resample
|
||||
self.do_rescale = do_rescale
|
||||
self.rescale_factor = rescale_factor
|
||||
self.do_normalize = do_normalize
|
||||
self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
|
||||
self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
|
||||
self.do_convert_rgb = do_convert_rgb
|
||||
|
||||
def resize(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
size: dict[str, int],
|
||||
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Resize an image to `(size["height"], size["width"])`.
|
||||
|
||||
Args:
|
||||
image (`np.ndarray`):
|
||||
Image to resize.
|
||||
size (`Dict[str, int]`):
|
||||
Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
|
||||
`PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`.
|
||||
data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the output image. If unset, the channel dimension format of the input
|
||||
image is used. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
|
||||
Returns:
|
||||
`np.ndarray`: The resized image.
|
||||
"""
|
||||
size = get_size_dict(size)
|
||||
if "height" not in size or "width" not in size:
|
||||
raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
|
||||
output_size = (size["height"], size["width"])
|
||||
return resize(
|
||||
image,
|
||||
size=output_size,
|
||||
resample=resample,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@filter_out_non_signature_kwargs()
|
||||
def preprocess(
|
||||
self,
|
||||
images: ImageInput,
|
||||
do_resize: Optional[bool] = None,
|
||||
size: Optional[dict[str, int]] = None,
|
||||
crop_to_patches: Optional[bool] = None,
|
||||
min_patches: Optional[int] = None,
|
||||
max_patches: Optional[int] = None,
|
||||
resample: PILImageResampling = None,
|
||||
do_rescale: Optional[bool] = None,
|
||||
rescale_factor: Optional[float] = None,
|
||||
do_normalize: Optional[bool] = None,
|
||||
image_mean: Optional[Union[float, list[float]]] = None,
|
||||
image_std: Optional[Union[float, list[float]]] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
do_convert_rgb: Optional[bool] = None,
|
||||
data_format: ChannelDimension = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
use_covering_area_grid: bool = True,
|
||||
) -> PIL.Image.Image:
|
||||
"""
|
||||
Preprocess an image or batch of images.
|
||||
|
||||
Args:
|
||||
images (`ImageInput`):
|
||||
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
|
||||
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
|
||||
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
||||
Whether to resize the image.
|
||||
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
||||
Controls the size of the image after `resize`. The shortest edge of the image is resized to
|
||||
`size["shortest_edge"]` whilst preserving the aspect ratio. If the longest edge of this resized image
|
||||
is > `int(size["shortest_edge"] * (1333 / 800))`, then the image is resized again to make the longest
|
||||
edge equal to `int(size["shortest_edge"] * (1333 / 800))`.
|
||||
crop_to_patches (`bool`, *optional*, defaults to `self.crop_to_patches`):
|
||||
Whether to crop the image to patches.
|
||||
min_patches (`int`, *optional*, defaults to `self.min_patches`):
|
||||
The minimum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is
|
||||
set to `True`.
|
||||
max_patches (`int`, *optional*, defaults to `self.max_patches`):
|
||||
The maximum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is
|
||||
set to `True`.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
|
||||
Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`.
|
||||
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
||||
Whether to rescale the image values between [0 - 1].
|
||||
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
||||
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
||||
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
||||
Whether to normalize the image.
|
||||
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
||||
Image mean to normalize the image by if `do_normalize` is set to `True`.
|
||||
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
||||
Image standard deviation to normalize the image by if `do_normalize` is set to `True`.
|
||||
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
||||
Whether to convert the image to RGB.
|
||||
return_tensors (`str` or `TensorType`, *optional*):
|
||||
The type of tensors to return. Can be one of:
|
||||
- Unset: Return a list of `np.ndarray`.
|
||||
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
||||
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
||||
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
||||
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
||||
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
||||
The channel dimension format for the output image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- Unset: Use the channel dimension format of the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
use_covering_area_grid (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use the covering area grid to determine the number of patches. Only has an effect if
|
||||
`crop_to_patches` is set to `True`.
|
||||
"""
|
||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||
crop_to_patches = crop_to_patches if crop_to_patches is not None else self.crop_to_patches
|
||||
min_patches = min_patches if min_patches is not None else self.min_patches
|
||||
max_patches = max_patches if max_patches is not None else self.max_patches
|
||||
resample = resample if resample is not None else self.resample
|
||||
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
||||
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
||||
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
||||
image_mean = image_mean if image_mean is not None else self.image_mean
|
||||
image_std = image_std if image_std is not None else self.image_std
|
||||
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
|
||||
|
||||
size = size if size is not None else self.size
|
||||
size = get_size_dict(size, default_to_square=False)
|
||||
images = make_flat_list_of_images(images)
|
||||
|
||||
if not valid_images(images):
|
||||
raise ValueError(
|
||||
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
||||
"torch.Tensor, tf.Tensor or jax.ndarray."
|
||||
)
|
||||
|
||||
validate_preprocess_arguments(
|
||||
do_rescale=do_rescale,
|
||||
rescale_factor=rescale_factor,
|
||||
do_normalize=do_normalize,
|
||||
image_mean=image_mean,
|
||||
image_std=image_std,
|
||||
do_resize=do_resize,
|
||||
size=size,
|
||||
resample=resample,
|
||||
)
|
||||
# PIL RGBA images are converted to RGB
|
||||
if do_convert_rgb:
|
||||
images = [convert_to_rgb(image) for image in images]
|
||||
|
||||
# All transformations expect numpy arrays.
|
||||
images = [to_numpy_array(image) for image in images]
|
||||
|
||||
if do_rescale and is_scaled_image(images[0]):
|
||||
logger.warning_once(
|
||||
"It looks like you are trying to rescale already rescaled images. If the input"
|
||||
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
||||
)
|
||||
|
||||
if input_data_format is None:
|
||||
# We assume that all images have the same channel dimension format.
|
||||
input_data_format = infer_channel_dimension_format(images[0])
|
||||
|
||||
if crop_to_patches and max_patches > 1:
|
||||
images = [
|
||||
self.crop_image_to_patches(
|
||||
image,
|
||||
min_patches=min_patches,
|
||||
max_patches=max_patches,
|
||||
patch_size=size,
|
||||
data_format=input_data_format,
|
||||
use_covering_area_grid=use_covering_area_grid,
|
||||
)
|
||||
for image in images
|
||||
]
|
||||
grids = [grid for _, grid in images]
|
||||
images = [image for images_list, _ in images for image in images_list]
|
||||
else:
|
||||
grids = [(1, 1)] * len(images)
|
||||
|
||||
for i, image in enumerate(images):
|
||||
if do_resize:
|
||||
images[i] = self.resize(image, size=size, resample=resample, input_data_format=input_data_format)
|
||||
|
||||
if do_rescale:
|
||||
images[i] = self.rescale(image=images[i], scale=rescale_factor, input_data_format=input_data_format)
|
||||
|
||||
if do_normalize:
|
||||
images[i] = self.normalize(
|
||||
image=images[i],
|
||||
mean=image_mean,
|
||||
std=image_std,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
|
||||
images[i] = to_channel_dimension_format(images[i], data_format, input_channel_dim=input_data_format)
|
||||
|
||||
encoded_outputs = BatchFeature(data={"pixel_values": images, "grids": grids}, tensor_type=return_tensors)
|
||||
|
||||
return encoded_outputs
|
||||
|
||||
def crop_image_to_patches(
|
||||
self,
|
||||
images: np.ndarray,
|
||||
min_patches: int,
|
||||
max_patches: int,
|
||||
use_covering_area_grid: bool = True,
|
||||
patch_size: Optional[Union[tuple, int, dict]] = None,
|
||||
data_format: ChannelDimension = None,
|
||||
covering_threshold: float = 0.9,
|
||||
):
|
||||
"""
|
||||
Crop the image to patches and return a list of cropped images.
|
||||
The number of patches and their grid arrangement are determined by the original image size,
|
||||
the target patch size and the minimum and maximum number of patches.
|
||||
The aspect ratio of the patches grid is chosen to be the closest to the original image aspect ratio.
|
||||
|
||||
Args:
|
||||
images (`np.ndarray`):
|
||||
The image to be cropped.
|
||||
min_patches (`int`):
|
||||
The minimum number of patches to be extracted from the image.
|
||||
max_patches (`int`):
|
||||
The maximum number of patches to be extracted from the image.
|
||||
use_covering_area_grid (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use the covering area grid to determine the number of patches.
|
||||
patch_size (`int`, `Tuple[int, int]`, `dict`, *optional*):
|
||||
The size of the output patches.
|
||||
data_format (`ChannelDimension`, *optional*):
|
||||
The format of the image data. If `None`, the format is inferred from the input image.
|
||||
covering_threshold (`float`, *optional*, defaults to `0.9`):
|
||||
The threshold for the covering area grid. If the covering area is less than this value, the grid is
|
||||
considered invalid.
|
||||
|
||||
Returns:
|
||||
List[`PIL.Image.Image`] or List[np.ndarray]: The list of cropped images.
|
||||
"""
|
||||
if data_format is None:
|
||||
data_format = infer_channel_dimension_format(images)
|
||||
images = to_channel_dimension_format(images, ChannelDimension.FIRST, data_format)
|
||||
patch_size_height, patch_size_width = patch_size["height"], patch_size["width"]
|
||||
original_height, original_width = images.shape[-2:]
|
||||
|
||||
if use_covering_area_grid:
|
||||
# Use the original OVIS2 approach: compute the minimal number of tiles that cover at least 90% of the image area
|
||||
num_columns, num_rows = get_min_tile_covering_grid(
|
||||
(original_height, original_width),
|
||||
target_patch_size=patch_size_height, # square patch size
|
||||
max_image_tiles=max_patches,
|
||||
covering_threshold=covering_threshold,
|
||||
)
|
||||
else:
|
||||
# find the closest aspect ratio to the target
|
||||
num_columns, num_rows = get_optimal_tiled_canvas(
|
||||
(original_height, original_width),
|
||||
(patch_size_height, patch_size_width),
|
||||
min_patches,
|
||||
max_patches,
|
||||
)
|
||||
|
||||
# calculate the target width and height
|
||||
target_width = patch_size_width * num_columns
|
||||
target_height = patch_size_height * num_rows
|
||||
num_blocks = num_columns * num_rows
|
||||
|
||||
# resize the image so that each patch is of patch_size
|
||||
resized_image = self.resize(
|
||||
images,
|
||||
{"height": target_height, "width": target_width},
|
||||
data_format=ChannelDimension.FIRST,
|
||||
input_data_format=ChannelDimension.FIRST,
|
||||
)
|
||||
|
||||
# split the image into patches
|
||||
processed_images = []
|
||||
for i in range(num_blocks):
|
||||
column = i % num_columns
|
||||
row = i // num_columns
|
||||
box = (
|
||||
column * patch_size_width,
|
||||
row * patch_size_height,
|
||||
(column + 1) * patch_size_width,
|
||||
(row + 1) * patch_size_height,
|
||||
)
|
||||
# split the image
|
||||
patch_image = resized_image[..., box[1] : box[3], box[0] : box[2]]
|
||||
patch_image = to_channel_dimension_format(patch_image, data_format, ChannelDimension.FIRST)
|
||||
processed_images.append(patch_image)
|
||||
|
||||
if len(processed_images) != 1:
|
||||
thumbnail_img = self.resize(
|
||||
images, patch_size, data_format=data_format, input_data_format=ChannelDimension.FIRST
|
||||
)
|
||||
processed_images.insert(0, thumbnail_img)
|
||||
|
||||
return processed_images, (num_rows, num_columns)
|
||||
|
||||
|
||||
__all__ = ["Ovis2ImageProcessor"]
|
||||
254
src/transformers/models/ovis2/image_processing_ovis2_fast.py
Normal file
254
src/transformers/models/ovis2/image_processing_ovis2_fast.py
Normal file
@@ -0,0 +1,254 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
from ...image_processing_utils import BatchFeature
|
||||
from ...image_processing_utils_fast import (
|
||||
BaseImageProcessorFast,
|
||||
DefaultFastImageProcessorKwargs,
|
||||
group_images_by_shape,
|
||||
reorder_images,
|
||||
)
|
||||
from ...image_utils import (
|
||||
OPENAI_CLIP_MEAN,
|
||||
OPENAI_CLIP_STD,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
SizeDict,
|
||||
)
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
TensorType,
|
||||
auto_docstring,
|
||||
is_torch_available,
|
||||
is_torchvision_available,
|
||||
is_torchvision_v2_available,
|
||||
)
|
||||
from .image_processing_ovis2 import get_min_tile_covering_grid, get_optimal_tiled_canvas
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_torchvision_available():
|
||||
if is_torchvision_v2_available():
|
||||
from torchvision.transforms.v2 import functional as F
|
||||
else:
|
||||
from torchvision.transforms import functional as F
|
||||
|
||||
|
||||
class Ovis2ImageProcessorKwargs(DefaultFastImageProcessorKwargs):
|
||||
"""
|
||||
Args:
|
||||
crop_to_patches (`bool`, *optional*, defaults to `False`):
|
||||
Whether to crop the image to patches. Can be overridden by the `crop_to_patches` parameter in the
|
||||
`preprocess` method.
|
||||
min_patches (`int`, *optional*, defaults to 1):
|
||||
The minimum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is
|
||||
set to `True`. Can be overridden by the `min_patches` parameter in the `preprocess` method.
|
||||
max_patches (`int`, *optional*, defaults to 12):
|
||||
The maximum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is
|
||||
set to `True`. Can be overridden by the `max_patches` parameter in the `preprocess` method.
|
||||
use_covering_area_grid (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use the covering area grid to determine the number of patches. Only has an effect if
|
||||
`crop_to_patches` is set to `True`. Can be overridden by the `use_covering_area_grid` parameter in the
|
||||
`preprocess` method.
|
||||
"""
|
||||
|
||||
crop_to_patches: Optional[bool]
|
||||
min_patches: Optional[int]
|
||||
max_patches: Optional[int]
|
||||
use_covering_area_grid: Optional[bool]
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class Ovis2ImageProcessorFast(BaseImageProcessorFast):
|
||||
resample = PILImageResampling.BICUBIC
|
||||
image_mean = OPENAI_CLIP_MEAN
|
||||
image_std = OPENAI_CLIP_STD
|
||||
size = {"height": 384, "width": 384}
|
||||
default_to_square = None
|
||||
do_resize = True
|
||||
do_rescale = True
|
||||
do_normalize = True
|
||||
do_convert_rgb = True
|
||||
crop_to_patches = False
|
||||
min_patches = 1
|
||||
max_patches = 12
|
||||
use_covering_area_grid = True
|
||||
valid_kwargs = Ovis2ImageProcessorKwargs
|
||||
|
||||
@auto_docstring
|
||||
def preprocess(self, images: ImageInput, **kwargs: Unpack[Ovis2ImageProcessorKwargs]) -> BatchFeature:
|
||||
return super().preprocess(images, **kwargs)
|
||||
|
||||
def crop_image_to_patches(
|
||||
self,
|
||||
images: "torch.Tensor",
|
||||
min_patches: int,
|
||||
max_patches: int,
|
||||
use_covering_area_grid: bool = True,
|
||||
covering_threshold: float = 0.9,
|
||||
patch_size: Optional[Union[tuple, int, dict]] = None,
|
||||
interpolation: Optional["F.InterpolationMode"] = None,
|
||||
):
|
||||
"""
|
||||
Crop the images to patches and return a list of cropped images.
|
||||
The number of patches and their grid arrangement are determined by the original image size,
|
||||
the target patch size and the minimum and maximum number of patches.
|
||||
The aspect ratio of the patches grid is chosen to be the closest to the original image aspect ratio.
|
||||
|
||||
Args:
|
||||
images (`torch.Tensor`):
|
||||
The images to be cropped.
|
||||
min_patches (`int`):
|
||||
The minimum number of patches to be extracted from the image.
|
||||
max_patches (`int`):
|
||||
The maximum number of patches to be extracted from the image.
|
||||
use_covering_area_grid (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use the original OVIS2 approach: compute the minimal number of tiles that cover at least 90%
|
||||
of the image area. If `False`, the closest aspect ratio to the target is used.
|
||||
covering_threshold (`float`, *optional*, defaults to `0.9`):
|
||||
The threshold for the covering area. Only has an effect if `use_covering_area_grid` is set to `True`.
|
||||
patch_size (`int`, `Tuple[int, int]`, `dict`, *optional*):
|
||||
The size of the output patches.
|
||||
The format of the image data. If `None`, the format is inferred from the input image.
|
||||
interpolation (`InterpolationMode`):
|
||||
Resampling filter to use if resizing the image.
|
||||
|
||||
Returns:
|
||||
List[`PIL.Image.Image`] or List[np.ndarray]: The list of cropped images.
|
||||
"""
|
||||
num_image = images.shape[0]
|
||||
patch_size_height, patch_size_width = patch_size.height, patch_size.width
|
||||
original_height, original_width = images.shape[-2:]
|
||||
|
||||
if use_covering_area_grid:
|
||||
# Use the original OVIS2 approach: compute the minimal number of tiles that cover at least 90% of the image area
|
||||
num_columns, num_rows = get_min_tile_covering_grid(
|
||||
(original_height, original_width),
|
||||
target_patch_size=patch_size_height, # square patch size
|
||||
max_image_tiles=max_patches,
|
||||
covering_threshold=covering_threshold,
|
||||
)
|
||||
else:
|
||||
# find the closest aspect ratio to the target
|
||||
num_columns, num_rows = get_optimal_tiled_canvas(
|
||||
(original_height, original_width), (patch_size_height, patch_size_width), min_patches, max_patches
|
||||
)
|
||||
|
||||
# calculate the target width and height
|
||||
target_width = patch_size_width * num_columns
|
||||
target_height = patch_size_height * num_rows
|
||||
num_blocks = num_columns * num_rows
|
||||
|
||||
# resize the image so that each patch is of patch_size
|
||||
resized_image = self.resize(
|
||||
images, SizeDict(height=target_height, width=target_width), interpolation=interpolation
|
||||
)
|
||||
# split the image into patches
|
||||
processed_images = []
|
||||
for i in range(num_blocks):
|
||||
column = i % num_columns
|
||||
row = i // num_columns
|
||||
box = (
|
||||
column * patch_size_width,
|
||||
row * patch_size_height,
|
||||
(column + 1) * patch_size_width,
|
||||
(row + 1) * patch_size_height,
|
||||
)
|
||||
# split the image
|
||||
patch_image = resized_image[..., box[1] : box[3], box[0] : box[2]]
|
||||
processed_images.append(patch_image)
|
||||
|
||||
if len(processed_images) != 1:
|
||||
thumbnail_img = self.resize(images, patch_size, interpolation=interpolation)
|
||||
processed_images.insert(0, thumbnail_img)
|
||||
|
||||
processed_images = torch.stack(processed_images, dim=0).transpose(0, 1).contiguous()
|
||||
grid = [[num_rows, num_columns] for _ in range(num_image)]
|
||||
|
||||
return processed_images, grid
|
||||
|
||||
def _preprocess(
|
||||
self,
|
||||
images: list["torch.Tensor"],
|
||||
do_resize: bool,
|
||||
size: SizeDict,
|
||||
crop_to_patches: bool,
|
||||
min_patches: int,
|
||||
max_patches: int,
|
||||
use_covering_area_grid: bool,
|
||||
interpolation: Optional["F.InterpolationMode"],
|
||||
do_center_crop: bool,
|
||||
crop_size: SizeDict,
|
||||
do_rescale: bool,
|
||||
rescale_factor: float,
|
||||
do_normalize: bool,
|
||||
image_mean: Optional[Union[float, list[float]]],
|
||||
image_std: Optional[Union[float, list[float]]],
|
||||
disable_grouping: Optional[bool],
|
||||
return_tensors: Optional[Union[str, TensorType]],
|
||||
) -> BatchFeature:
|
||||
if crop_to_patches and max_patches > 1:
|
||||
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
|
||||
processed_images_grouped = {}
|
||||
grids = {}
|
||||
for shape, stacked_images in grouped_images.items():
|
||||
stacked_images, grid = self.crop_image_to_patches(
|
||||
stacked_images,
|
||||
min_patches,
|
||||
max_patches,
|
||||
patch_size=size,
|
||||
use_covering_area_grid=use_covering_area_grid,
|
||||
interpolation=interpolation,
|
||||
)
|
||||
processed_images_grouped[shape] = stacked_images
|
||||
grids[shape] = grid
|
||||
images = reorder_images(processed_images_grouped, grouped_images_index)
|
||||
images = [image for images_list in images for image in images_list]
|
||||
grids = reorder_images(grids, grouped_images_index)
|
||||
else:
|
||||
grids = [[1, 1] for _ in range(len(images))]
|
||||
|
||||
# Group images by size for batched resizing
|
||||
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
|
||||
resized_images_grouped = {}
|
||||
for shape, stacked_images in grouped_images.items():
|
||||
if do_resize:
|
||||
stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation)
|
||||
resized_images_grouped[shape] = stacked_images
|
||||
resized_images = reorder_images(resized_images_grouped, grouped_images_index)
|
||||
|
||||
# Group images by size for further processing
|
||||
# Needed in case do_resize is False, or resize returns images with different sizes
|
||||
grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
|
||||
processed_images_grouped = {}
|
||||
for shape, stacked_images in grouped_images.items():
|
||||
if do_center_crop:
|
||||
stacked_images = self.center_crop(stacked_images, crop_size)
|
||||
# Fused rescale and normalize
|
||||
stacked_images = self.rescale_and_normalize(
|
||||
stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
|
||||
)
|
||||
processed_images_grouped[shape] = stacked_images
|
||||
|
||||
processed_images = reorder_images(processed_images_grouped, grouped_images_index)
|
||||
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
|
||||
return BatchFeature(data={"pixel_values": processed_images, "grids": grids}, tensor_type=return_tensors)
|
||||
|
||||
|
||||
__all__ = ["Ovis2ImageProcessorFast"]
|
||||
902
src/transformers/models/ovis2/modeling_ovis2.py
Normal file
902
src/transformers/models/ovis2/modeling_ovis2.py
Normal file
@@ -0,0 +1,902 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from src/transformers/models/ovis2/modular_ovis2.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_ovis2.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...generation import GenerationMixin
|
||||
from ...integrations import use_kernel_forward_from_hub
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...utils import ModelOutput, auto_docstring, can_return_tuple, is_torchdynamo_compiling
|
||||
from ..auto import AutoModel
|
||||
from .configuration_ovis2 import Ovis2Config, Ovis2VisionConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
Base class for Llava outputs, with hidden states and attentions.
|
||||
"""
|
||||
)
|
||||
class Ovis2ModelOutputWithPast(BaseModelOutputWithPast):
|
||||
r"""
|
||||
past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
|
||||
`past_key_values` input) to speed up sequential decoding.
|
||||
image_hidden_states (`torch.FloatTensor`, *optional*):
|
||||
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
|
||||
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
|
||||
"""
|
||||
|
||||
image_hidden_states: Optional[torch.FloatTensor] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
Base class for Ovis2 causal language model (or autoregressive) outputs.
|
||||
"""
|
||||
)
|
||||
class Ovis2CausalLMOutputWithPast(ModelOutput):
|
||||
r"""
|
||||
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
||||
Language modeling loss (for next-token prediction).
|
||||
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
|
||||
`past_key_values` input) to speed up sequential decoding.
|
||||
image_hidden_states (`torch.FloatTensor`, *optional*):
|
||||
A `torch.FloatTensor` of size (batch_size * num_patches, num_images, sequence_length, hidden_size)`.
|
||||
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
|
||||
"""
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
logits: Optional[torch.FloatTensor] = None
|
||||
past_key_values: Optional[list[torch.FloatTensor]] = None
|
||||
hidden_states: Optional[tuple[torch.FloatTensor]] = None
|
||||
attentions: Optional[tuple[torch.FloatTensor]] = None
|
||||
image_hidden_states: Optional[torch.FloatTensor] = None
|
||||
|
||||
|
||||
@use_kernel_forward_from_hub("RMSNorm")
|
||||
class Ovis2RMSNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
"""
|
||||
Ovis2RMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
def extra_repr(self):
|
||||
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
||||
|
||||
|
||||
class Ovis2VisionMLP(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
|
||||
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
|
||||
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
|
||||
self.act_fn = ACT2FN[config.hidden_act]
|
||||
|
||||
def forward(self, x):
|
||||
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
return down_proj
|
||||
|
||||
|
||||
class Ovis2VisionEmbeddings(nn.Module):
|
||||
def __init__(self, config: Ovis2VisionConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.image_size = config.image_size
|
||||
self.patch_size = config.patch_size
|
||||
|
||||
self.patch_embedding = nn.Conv2d(
|
||||
in_channels=config.num_channels,
|
||||
out_channels=self.embed_dim,
|
||||
kernel_size=self.patch_size,
|
||||
stride=self.patch_size,
|
||||
padding="valid",
|
||||
)
|
||||
|
||||
self.num_patches = (self.image_size // self.patch_size) ** 2
|
||||
self.num_positions = self.num_patches
|
||||
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
|
||||
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
|
||||
self.rms_norm = Ovis2RMSNorm(config.hidden_size, config.rms_norm_eps)
|
||||
|
||||
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
||||
target_dtype = self.patch_embedding.weight.dtype
|
||||
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
|
||||
embeddings = patch_embeds.flatten(2).transpose(1, 2)
|
||||
embeddings = self.rms_norm(embeddings)
|
||||
|
||||
embeddings = embeddings + self.position_embedding(self.position_ids)
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
**kwargs,
|
||||
):
|
||||
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class Ovis2VisionAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
if self.head_dim * self.num_heads != self.embed_dim:
|
||||
raise ValueError(
|
||||
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
|
||||
f" {self.num_heads})."
|
||||
)
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.dropout = config.attention_dropout
|
||||
self.is_causal = False
|
||||
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
|
||||
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
|
||||
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
|
||||
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
batch_size, seq_length, embed_dim = hidden_states.shape
|
||||
|
||||
queries = self.q_proj(hidden_states)
|
||||
keys = self.k_proj(hidden_states)
|
||||
values = self.v_proj(hidden_states)
|
||||
|
||||
queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
queries,
|
||||
keys,
|
||||
values,
|
||||
attention_mask,
|
||||
is_causal=self.is_causal,
|
||||
scaling=self.scale,
|
||||
dropout=0.0 if not self.training else self.dropout,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class Ovis2MLP(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
|
||||
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
|
||||
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
|
||||
self.act_fn = ACT2FN[config.hidden_act]
|
||||
|
||||
def forward(self, x):
|
||||
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
return down_proj
|
||||
|
||||
|
||||
class Ovis2Attention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
if self.head_dim * self.num_heads != self.embed_dim:
|
||||
raise ValueError(
|
||||
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
|
||||
f" {self.num_heads})."
|
||||
)
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.dropout = config.attention_dropout
|
||||
self.is_causal = False
|
||||
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
|
||||
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
|
||||
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
|
||||
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
batch_size, seq_length, embed_dim = hidden_states.shape
|
||||
|
||||
queries = self.q_proj(hidden_states)
|
||||
keys = self.k_proj(hidden_states)
|
||||
values = self.v_proj(hidden_states)
|
||||
|
||||
queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
queries,
|
||||
keys,
|
||||
values,
|
||||
attention_mask,
|
||||
is_causal=self.is_causal,
|
||||
scaling=self.scale,
|
||||
dropout=0.0 if not self.training else self.dropout,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class Ovis2VisionEncoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config: Ovis2VisionConfig):
|
||||
super().__init__()
|
||||
self.attention = Ovis2Attention(config)
|
||||
self.ffn = Ovis2MLP(config)
|
||||
self.rms_norm1 = Ovis2RMSNorm(config.hidden_size, config.rms_norm_eps)
|
||||
self.rms_norm2 = Ovis2RMSNorm(config.hidden_size, config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
norm_hidden_states = self.rms_norm1(hidden_states)
|
||||
attn_output, attn_weights = self.attention(hidden_states=norm_hidden_states, attention_mask=attention_mask)
|
||||
|
||||
hidden_states = hidden_states + attn_output
|
||||
norm_hidden_states = self.rms_norm2(hidden_states)
|
||||
mlp_output = self.ffn(norm_hidden_states)
|
||||
|
||||
hidden_states = hidden_states + mlp_output
|
||||
return (hidden_states, attn_weights) if output_attentions else (hidden_states, None)
|
||||
|
||||
|
||||
class Ovis2VisionEncoder(nn.Module):
|
||||
"""
|
||||
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
|
||||
[`Ovis2VisionEncoderLayer`].
|
||||
|
||||
Args:
|
||||
config: Ovis2VisionConfig
|
||||
"""
|
||||
|
||||
def __init__(self, config: Ovis2VisionConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layers = nn.ModuleList([Ovis2VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
# Ignore copy
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
) -> BaseModelOutput:
|
||||
r"""
|
||||
Args:
|
||||
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
||||
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
|
||||
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
||||
than the model's internal embedding lookup matrix.
|
||||
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
||||
for more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
|
||||
encoder_states = () if output_hidden_states else None
|
||||
all_attentions = () if output_attentions else None
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
for encoder_layer in self.layers:
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if output_attentions:
|
||||
all_attentions = all_attentions + (layer_outputs[1],)
|
||||
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states,
|
||||
hidden_states=encoder_states,
|
||||
attentions=all_attentions,
|
||||
)
|
||||
|
||||
|
||||
class Ovis2VisionTransformer(nn.Module):
|
||||
def __init__(self, config: Ovis2VisionConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embeddings = Ovis2VisionEmbeddings(config)
|
||||
self.encoder = Ovis2VisionEncoder(config)
|
||||
self.rms_norm = Ovis2RMSNorm(config.hidden_size, config.rms_norm_eps)
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
pixel_values,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
):
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
|
||||
hidden_states = self.embeddings(pixel_values)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
inputs_embeds=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
last_hidden_state = encoder_outputs[0]
|
||||
last_hidden_state = self.rms_norm(last_hidden_state)
|
||||
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=last_hidden_state,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
class Ovis2VisualEmbeddingTable(nn.Embedding):
|
||||
def forward(self, visual_tokens: torch.Tensor) -> torch.Tensor:
|
||||
if visual_tokens.dtype in [torch.int8, torch.int16, torch.int32, torch.int64, torch.long]:
|
||||
return super().forward(visual_tokens)
|
||||
return torch.matmul(visual_tokens, self.weight)
|
||||
|
||||
|
||||
class Ovis2PreTrainedModel(PreTrainedModel):
|
||||
config: Ovis2Config
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["Ovis2VisionAttention"]
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
_supports_cache_class = True
|
||||
_supports_flash_attn = True
|
||||
_supports_flex_attn = True
|
||||
_supports_sdpa = True
|
||||
|
||||
_can_compile_fullgraph = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
|
||||
def hard_softmax(logits: torch.Tensor, dim: int):
|
||||
y_soft = logits.softmax(dim)
|
||||
# Straight through.
|
||||
index = y_soft.max(dim, keepdim=True)[1]
|
||||
y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
|
||||
ret = y_hard - y_soft.detach() + y_soft
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
class Ovis2VisionModel(Ovis2PreTrainedModel):
|
||||
config: Ovis2VisionConfig
|
||||
|
||||
def __init__(self, config: Ovis2VisionConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.transformer = Ovis2VisionTransformer(config)
|
||||
self.num_visual_indicator_tokens = config.num_visual_indicator_tokens
|
||||
self.vocab_size = config.vocab_size
|
||||
self.head_linear = nn.Linear(
|
||||
config.hidden_size * config.hidden_stride * config.hidden_stride,
|
||||
self.vocab_size - self.num_visual_indicator_tokens,
|
||||
bias=False,
|
||||
)
|
||||
self.head_norm = nn.LayerNorm(self.vocab_size - self.num_visual_indicator_tokens)
|
||||
|
||||
def forward(self, pixel_values: torch.FloatTensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
outputs = self.transformer(pixel_values)
|
||||
last_hidden_state = outputs.last_hidden_state
|
||||
|
||||
if self.config.hidden_stride > 1:
|
||||
num_images, seq_len, hidden_dim = last_hidden_state.shape
|
||||
hidden_stride = self.config.hidden_stride
|
||||
|
||||
sqrt_l = int(math.sqrt(seq_len))
|
||||
if sqrt_l * sqrt_l != seq_len:
|
||||
raise ValueError("Token sequence length must be a perfect square")
|
||||
|
||||
pad_size = (hidden_stride - (sqrt_l % hidden_stride)) % hidden_stride
|
||||
last_hidden_state = nn.functional.pad(last_hidden_state, (0, 0, 0, pad_size, 0, pad_size), "constant", 0)
|
||||
sqrt_l += pad_size
|
||||
|
||||
last_hidden_state = last_hidden_state.reshape(
|
||||
num_images, sqrt_l // hidden_stride, hidden_stride, sqrt_l // hidden_stride, hidden_stride, hidden_dim
|
||||
)
|
||||
last_hidden_state = last_hidden_state.permute(0, 1, 3, 2, 4, 5)
|
||||
last_hidden_state = last_hidden_state.reshape(
|
||||
num_images, -1, hidden_stride * hidden_stride * hidden_dim
|
||||
) # (n, (sqrt_l//hs)^2, hs^2*d)
|
||||
|
||||
logits = self.head_linear(last_hidden_state)
|
||||
logits = self.head_norm(logits)
|
||||
|
||||
if self.config.tokenize_function == "gumbel_argmax":
|
||||
prob_token = nn.functional.gumbel_softmax(logits, dim=-1, hard=True)
|
||||
elif self.config.tokenize_function == "st_argmax":
|
||||
prob_token = hard_softmax(logits, dim=-1)
|
||||
elif self.config.tokenize_function == "softmax":
|
||||
prob_token = nn.functional.softmax(logits, dim=-1)
|
||||
|
||||
return prob_token
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
The Ovis2 model which consists of a vision backbone and a language model, without a language modeling head.
|
||||
"""
|
||||
)
|
||||
class Ovis2Model(Ovis2PreTrainedModel):
|
||||
_checkpoint_conversion_mapping = {}
|
||||
|
||||
def __init__(self, config: Ovis2Config):
|
||||
super().__init__(config)
|
||||
self.vision_tower = Ovis2VisionModel(config.vision_config)
|
||||
self.language_model = AutoModel.from_config(config.text_config)
|
||||
self.visual_embeddings_table = Ovis2VisualEmbeddingTable(config.vision_config.vocab_size, config.hidden_size)
|
||||
|
||||
self.visual_vocab_size = config.vision_config.vocab_size
|
||||
self.vocab_size = config.vocab_size
|
||||
self.visual_indicator_token_ids = config.visual_indicator_token_ids
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.language_model.get_input_embeddings()
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.language_model.set_input_embeddings(value)
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.language_model = decoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.language_model
|
||||
|
||||
def get_image_features(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
Obtains image last hidden states from the vision tower and apply multimodal projection.
|
||||
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`):
|
||||
The tensors corresponding to the input images.
|
||||
vision_feature_layer (`Union[int, list[int]]`, *optional*):
|
||||
The index of the layer to select the vision feature. If multiple indices are provided,
|
||||
the vision feature of the corresponding indices will be concatenated to form the
|
||||
vision features.
|
||||
vision_feature_select_strategy (`str`, *optional*):
|
||||
The feature selection strategy used to select the vision feature from the vision backbone.
|
||||
Can be one of `"default"` or `"full"`
|
||||
Returns:
|
||||
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
|
||||
"""
|
||||
image_features = self.vision_tower(pixel_values)
|
||||
batch_size, img_seq_len, _ = image_features.shape
|
||||
padding_tensor = torch.zeros(
|
||||
(batch_size, img_seq_len, self.vision_tower.num_visual_indicator_tokens),
|
||||
dtype=image_features.dtype,
|
||||
device=image_features.device,
|
||||
requires_grad=False,
|
||||
layout=image_features.layout,
|
||||
)
|
||||
image_features = torch.cat([image_features, padding_tensor], dim=2)
|
||||
image_features = self.visual_embeddings_table(image_features)
|
||||
|
||||
visual_indicator = torch.arange(
|
||||
self.visual_vocab_size - self.vision_tower.num_visual_indicator_tokens,
|
||||
self.visual_vocab_size,
|
||||
dtype=torch.long,
|
||||
).to(image_features.device)
|
||||
visual_indicator_features = self.visual_embeddings_table(visual_indicator)
|
||||
|
||||
return image_features, visual_indicator_features
|
||||
|
||||
def get_placeholder_mask(
|
||||
self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
|
||||
):
|
||||
"""
|
||||
Obtains multimodal placeholdr mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
|
||||
equal to the length of multimodal features. If the lengths are different, an error is raised.
|
||||
"""
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
n_image_tokens = special_image_mask.sum()
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
n_image_features = image_features.shape[0] * image_features.shape[1]
|
||||
if inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
)
|
||||
return special_image_mask
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
pixel_values: torch.FloatTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**kwargs,
|
||||
) -> Union[tuple, Ovis2ModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
if pixel_values is not None:
|
||||
image_features, visual_indicator_features = self.get_image_features(pixel_values=pixel_values)
|
||||
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
n_image_tokens = special_image_mask.sum()
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
image_features = image_features.reshape(-1, image_features.shape[-1])
|
||||
n_image_features = image_features.shape[0]
|
||||
if not is_torchdynamo_compiling() and n_image_tokens != n_image_features:
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
||||
|
||||
for i, visual_indicator_id in enumerate(self.visual_indicator_token_ids):
|
||||
if input_ids is None:
|
||||
mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(visual_indicator_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
mask = mask.all(-1)
|
||||
else:
|
||||
mask = (input_ids == visual_indicator_id).to(inputs_embeds.device)
|
||||
|
||||
if mask.any():
|
||||
inputs_embeds[mask] = (
|
||||
visual_indicator_features[i]
|
||||
.expand_as(inputs_embeds[mask])
|
||||
.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
)
|
||||
|
||||
outputs = self.language_model(
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
logits_to_keep=logits_to_keep,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return Ovis2ModelOutputWithPast(
|
||||
last_hidden_state=outputs.last_hidden_state,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
image_hidden_states=image_features if pixel_values is not None else None,
|
||||
)
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class Ovis2ForConditionalGeneration(Ovis2PreTrainedModel, GenerationMixin):
|
||||
_checkpoint_conversion_mapping = {}
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
|
||||
def __init__(self, config: Ovis2Config):
|
||||
super().__init__(config)
|
||||
self.model = Ovis2Model(config)
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.get_input_embeddings()
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.set_input_embeddings(value)
|
||||
|
||||
def get_output_embeddings(self) -> nn.Module:
|
||||
return self.lm_head
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.model.set_decoder(decoder)
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model.get_decoder()
|
||||
|
||||
def get_image_features(self, pixel_values: torch.FloatTensor):
|
||||
return self.model.get_image_features(pixel_values=pixel_values)
|
||||
|
||||
# Make modules available throught conditional class for BC
|
||||
@property
|
||||
def language_model(self):
|
||||
return self.model.language_model
|
||||
|
||||
@property
|
||||
def vision_tower(self):
|
||||
return self.model.vision_tower
|
||||
|
||||
@property
|
||||
def multi_modal_projector(self):
|
||||
raise AttributeError("Not needed for Ovis2")
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
pixel_values: torch.FloatTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**kwargs,
|
||||
) -> Union[tuple, Ovis2CausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from transformers import AutoProcessor, Ovis2ForConditionalGeneration
|
||||
|
||||
>>> model = Ovis2ForConditionalGeneration.from_pretrained("thisisiron/Ovis2-2B-hf")
|
||||
>>> processor = AutoProcessor.from_pretrained("thisisiron/Ovis2-2B-hf")
|
||||
|
||||
>>> prompt = "<|im_start|>user\n<image>\nDescribe the image.<|im_end|>\n<|im_start|>assistant\n"
|
||||
>>> url = "http://images.cocodataset.org/val2014/COCO_val2014_000000537955.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(**inputs, max_new_tokens=15)
|
||||
>>> processor.batch_decode(generate_ids, skip_special_tokens=True)[0]
|
||||
"user\n\nDescribe the image.\nassistant\nThe image features a brown dog standing on a wooden floor, looking up with"
|
||||
```"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
pixel_values=pixel_values,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(
|
||||
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
|
||||
)
|
||||
|
||||
return Ovis2CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
image_hidden_states=outputs.image_hidden_states,
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
input_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
pixel_values=None,
|
||||
attention_mask=None,
|
||||
cache_position=None,
|
||||
logits_to_keep=None,
|
||||
**kwargs,
|
||||
):
|
||||
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
|
||||
|
||||
model_inputs = super().prepare_inputs_for_generation(
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
logits_to_keep=logits_to_keep,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if cache_position[0] == 0:
|
||||
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
|
||||
# Otherwise we need pixel values to be passed to model
|
||||
model_inputs["pixel_values"] = pixel_values
|
||||
|
||||
return model_inputs
|
||||
|
||||
|
||||
__all__ = ["Ovis2PreTrainedModel", "Ovis2Model", "Ovis2ForConditionalGeneration"]
|
||||
443
src/transformers/models/ovis2/modular_ovis2.py
Normal file
443
src/transformers/models/ovis2/modular_ovis2.py
Normal file
@@ -0,0 +1,443 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_outputs import BaseModelOutput
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import auto_docstring, can_return_tuple, is_torchdynamo_compiling
|
||||
from ..aimv2.modeling_aimv2 import Aimv2Attention, Aimv2EncoderLayer
|
||||
from ..auto import AutoModel
|
||||
from ..llama.modeling_llama import LlamaMLP, LlamaRMSNorm
|
||||
from ..llava.modeling_llava import LlavaForConditionalGeneration, LlavaModel
|
||||
from ..llava_next.modeling_llava_next import LlavaNextCausalLMOutputWithPast, LlavaNextModelOutputWithPast
|
||||
from ..siglip.modeling_siglip import SiglipEncoder, SiglipVisionEmbeddings
|
||||
from .configuration_ovis2 import Ovis2Config, Ovis2VisionConfig
|
||||
|
||||
|
||||
def hard_softmax(logits: torch.Tensor, dim: int):
|
||||
y_soft = logits.softmax(dim)
|
||||
# Straight through.
|
||||
index = y_soft.max(dim, keepdim=True)[1]
|
||||
y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
|
||||
ret = y_hard - y_soft.detach() + y_soft
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
class Ovis2ModelOutputWithPast(LlavaNextModelOutputWithPast):
|
||||
pass
|
||||
|
||||
|
||||
class Ovis2CausalLMOutputWithPast(LlavaNextCausalLMOutputWithPast):
|
||||
pass
|
||||
|
||||
|
||||
class Ovis2RMSNorm(LlamaRMSNorm):
|
||||
pass
|
||||
|
||||
|
||||
class Ovis2VisionMLP(LlamaMLP):
|
||||
pass
|
||||
|
||||
|
||||
class Ovis2VisionEmbeddings(SiglipVisionEmbeddings):
|
||||
def __init__(self, config: Ovis2VisionConfig):
|
||||
super().__init__()
|
||||
self.rms_norm = Ovis2RMSNorm(config.hidden_size, config.rms_norm_eps)
|
||||
|
||||
def interpolate_pos_encoding(self):
|
||||
raise NotImplementedError("Not needed for Ovis2")
|
||||
|
||||
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
||||
target_dtype = self.patch_embedding.weight.dtype
|
||||
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
|
||||
embeddings = patch_embeds.flatten(2).transpose(1, 2)
|
||||
embeddings = self.rms_norm(embeddings)
|
||||
|
||||
embeddings = embeddings + self.position_embedding(self.position_ids)
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
class Ovis2VisionAttention(Aimv2Attention):
|
||||
pass
|
||||
|
||||
|
||||
class Ovis2VisionEncoderLayer(Aimv2EncoderLayer):
|
||||
pass
|
||||
|
||||
|
||||
class Ovis2VisionEncoder(SiglipEncoder):
|
||||
def __init__(self, config: Ovis2VisionConfig):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([Ovis2VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
|
||||
|
||||
class Ovis2VisionTransformer(nn.Module):
|
||||
def __init__(self, config: Ovis2VisionConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embeddings = Ovis2VisionEmbeddings(config)
|
||||
self.encoder = Ovis2VisionEncoder(config)
|
||||
self.rms_norm = Ovis2RMSNorm(config.hidden_size, config.rms_norm_eps)
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
pixel_values,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
):
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
|
||||
hidden_states = self.embeddings(pixel_values)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
inputs_embeds=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
last_hidden_state = encoder_outputs[0]
|
||||
last_hidden_state = self.rms_norm(last_hidden_state)
|
||||
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=last_hidden_state,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
class Ovis2VisualEmbeddingTable(nn.Embedding):
|
||||
def forward(self, visual_tokens: torch.Tensor) -> torch.Tensor:
|
||||
if visual_tokens.dtype in [torch.int8, torch.int16, torch.int32, torch.int64, torch.long]:
|
||||
return super().forward(visual_tokens)
|
||||
return torch.matmul(visual_tokens, self.weight)
|
||||
|
||||
|
||||
class Ovis2PreTrainedModel(PreTrainedModel):
|
||||
config: Ovis2Config
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["Ovis2VisionAttention"]
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
_supports_cache_class = True
|
||||
_supports_flash_attn = True
|
||||
_supports_flex_attn = True
|
||||
_supports_sdpa = True
|
||||
|
||||
_can_compile_fullgraph = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
|
||||
class Ovis2VisionModel(Ovis2PreTrainedModel):
|
||||
config: Ovis2VisionConfig
|
||||
|
||||
def __init__(self, config: Ovis2VisionConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.transformer = Ovis2VisionTransformer(config)
|
||||
self.num_visual_indicator_tokens = config.num_visual_indicator_tokens
|
||||
self.vocab_size = config.vocab_size
|
||||
self.head_linear = nn.Linear(
|
||||
config.hidden_size * config.hidden_stride * config.hidden_stride,
|
||||
self.vocab_size - self.num_visual_indicator_tokens,
|
||||
bias=False,
|
||||
)
|
||||
self.head_norm = nn.LayerNorm(self.vocab_size - self.num_visual_indicator_tokens)
|
||||
|
||||
def forward(self, pixel_values: torch.FloatTensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
outputs = self.transformer(pixel_values)
|
||||
last_hidden_state = outputs.last_hidden_state
|
||||
|
||||
if self.config.hidden_stride > 1:
|
||||
num_images, seq_len, hidden_dim = last_hidden_state.shape
|
||||
hidden_stride = self.config.hidden_stride
|
||||
|
||||
sqrt_l = int(math.sqrt(seq_len))
|
||||
if sqrt_l * sqrt_l != seq_len:
|
||||
raise ValueError("Token sequence length must be a perfect square")
|
||||
|
||||
pad_size = (hidden_stride - (sqrt_l % hidden_stride)) % hidden_stride
|
||||
last_hidden_state = nn.functional.pad(last_hidden_state, (0, 0, 0, pad_size, 0, pad_size), "constant", 0)
|
||||
sqrt_l += pad_size
|
||||
|
||||
last_hidden_state = last_hidden_state.reshape(
|
||||
num_images, sqrt_l // hidden_stride, hidden_stride, sqrt_l // hidden_stride, hidden_stride, hidden_dim
|
||||
)
|
||||
last_hidden_state = last_hidden_state.permute(0, 1, 3, 2, 4, 5)
|
||||
last_hidden_state = last_hidden_state.reshape(
|
||||
num_images, -1, hidden_stride * hidden_stride * hidden_dim
|
||||
) # (n, (sqrt_l//hs)^2, hs^2*d)
|
||||
|
||||
logits = self.head_linear(last_hidden_state)
|
||||
logits = self.head_norm(logits)
|
||||
|
||||
if self.config.tokenize_function == "gumbel_argmax":
|
||||
prob_token = nn.functional.gumbel_softmax(logits, dim=-1, hard=True)
|
||||
elif self.config.tokenize_function == "st_argmax":
|
||||
prob_token = hard_softmax(logits, dim=-1)
|
||||
elif self.config.tokenize_function == "softmax":
|
||||
prob_token = nn.functional.softmax(logits, dim=-1)
|
||||
|
||||
return prob_token
|
||||
|
||||
|
||||
class Ovis2Model(LlavaModel):
|
||||
_checkpoint_conversion_mapping = {}
|
||||
|
||||
def __init__(self, config: Ovis2Config):
|
||||
super().__init__(config)
|
||||
self.vision_tower = Ovis2VisionModel(config.vision_config)
|
||||
self.visual_embeddings_table = Ovis2VisualEmbeddingTable(config.vision_config.vocab_size, config.hidden_size)
|
||||
|
||||
self.visual_vocab_size = config.vision_config.vocab_size
|
||||
self.vocab_size = config.vocab_size
|
||||
self.visual_indicator_token_ids = config.visual_indicator_token_ids
|
||||
self.language_model = AutoModel.from_config(config.text_config)
|
||||
del self.multi_modal_projector
|
||||
|
||||
def get_image_features(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
) -> torch.FloatTensor:
|
||||
image_features = self.vision_tower(pixel_values)
|
||||
batch_size, img_seq_len, _ = image_features.shape
|
||||
padding_tensor = torch.zeros(
|
||||
(batch_size, img_seq_len, self.vision_tower.num_visual_indicator_tokens),
|
||||
dtype=image_features.dtype,
|
||||
device=image_features.device,
|
||||
requires_grad=False,
|
||||
layout=image_features.layout,
|
||||
)
|
||||
image_features = torch.cat([image_features, padding_tensor], dim=2)
|
||||
image_features = self.visual_embeddings_table(image_features)
|
||||
|
||||
visual_indicator = torch.arange(
|
||||
self.visual_vocab_size - self.vision_tower.num_visual_indicator_tokens,
|
||||
self.visual_vocab_size,
|
||||
dtype=torch.long,
|
||||
).to(image_features.device)
|
||||
visual_indicator_features = self.visual_embeddings_table(visual_indicator)
|
||||
|
||||
return image_features, visual_indicator_features
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
pixel_values: torch.FloatTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**kwargs,
|
||||
) -> Union[tuple, Ovis2ModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
if pixel_values is not None:
|
||||
image_features, visual_indicator_features = self.get_image_features(pixel_values=pixel_values)
|
||||
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
n_image_tokens = special_image_mask.sum()
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
image_features = image_features.reshape(-1, image_features.shape[-1])
|
||||
n_image_features = image_features.shape[0]
|
||||
if not is_torchdynamo_compiling() and n_image_tokens != n_image_features:
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
||||
|
||||
for i, visual_indicator_id in enumerate(self.visual_indicator_token_ids):
|
||||
if input_ids is None:
|
||||
mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(visual_indicator_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
mask = mask.all(-1)
|
||||
else:
|
||||
mask = (input_ids == visual_indicator_id).to(inputs_embeds.device)
|
||||
|
||||
if mask.any():
|
||||
inputs_embeds[mask] = (
|
||||
visual_indicator_features[i]
|
||||
.expand_as(inputs_embeds[mask])
|
||||
.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
)
|
||||
|
||||
outputs = self.language_model(
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
logits_to_keep=logits_to_keep,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return Ovis2ModelOutputWithPast(
|
||||
last_hidden_state=outputs.last_hidden_state,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
image_hidden_states=image_features if pixel_values is not None else None,
|
||||
)
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class Ovis2ForConditionalGeneration(LlavaForConditionalGeneration, GenerationMixin):
|
||||
_checkpoint_conversion_mapping = {}
|
||||
|
||||
def __init__(self, config: Ovis2Config):
|
||||
super().__init__(config)
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
@property
|
||||
def multi_modal_projector(self):
|
||||
raise AttributeError("Not needed for Ovis2")
|
||||
|
||||
def get_image_features(self, pixel_values: torch.FloatTensor):
|
||||
return self.model.get_image_features(pixel_values=pixel_values)
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
pixel_values: torch.FloatTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**kwargs,
|
||||
) -> Union[tuple, Ovis2CausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from transformers import AutoProcessor, Ovis2ForConditionalGeneration
|
||||
|
||||
>>> model = Ovis2ForConditionalGeneration.from_pretrained("thisisiron/Ovis2-2B-hf")
|
||||
>>> processor = AutoProcessor.from_pretrained("thisisiron/Ovis2-2B-hf")
|
||||
|
||||
>>> prompt = "<|im_start|>user\n<image>\nDescribe the image.<|im_end|>\n<|im_start|>assistant\n"
|
||||
>>> url = "http://images.cocodataset.org/val2014/COCO_val2014_000000537955.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(**inputs, max_new_tokens=15)
|
||||
>>> processor.batch_decode(generate_ids, skip_special_tokens=True)[0]
|
||||
"user\n\nDescribe the image.\nassistant\nThe image features a brown dog standing on a wooden floor, looking up with"
|
||||
```"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
pixel_values=pixel_values,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(
|
||||
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
|
||||
)
|
||||
|
||||
return Ovis2CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
image_hidden_states=outputs.image_hidden_states,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["Ovis2PreTrainedModel", "Ovis2Model", "Ovis2ForConditionalGeneration"]
|
||||
181
src/transformers/models/ovis2/processing_ovis2.py
Normal file
181
src/transformers/models/ovis2/processing_ovis2.py
Normal file
@@ -0,0 +1,181 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Union
|
||||
|
||||
from ...feature_extraction_utils import BatchFeature
|
||||
from ...image_utils import ImageInput
|
||||
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
|
||||
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Ovis2ProcessorKwargs(ProcessingKwargs, total=False):
|
||||
_defaults = {
|
||||
"text_kwargs": {
|
||||
"padding": False,
|
||||
},
|
||||
"image_kwargs": {},
|
||||
}
|
||||
|
||||
|
||||
class Ovis2Processor(ProcessorMixin):
|
||||
r"""
|
||||
Constructs a Ovis2 processor which wraps Ovis2 image processor and a Qwen2 tokenizer into a single processor.
|
||||
|
||||
[`Ovis2Processor`] offers all the functionalities of [`Ovis2VideoProcessor`], [`Ovis2ImageProcessor`] and [`Qwen2TokenizerFast`]. See the
|
||||
[`~Ovis2Processor.__call__`] and [`~Ovis2Processor.decode`] for more information.
|
||||
|
||||
Args:
|
||||
image_processor ([`Ovis2ImageProcessor`], *optional*):
|
||||
The image processor is a required input.
|
||||
tokenizer ([`Qwen2TokenizerFast`], *optional*):
|
||||
The tokenizer is a required input.
|
||||
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
|
||||
in a chat into a tokenizable string.
|
||||
image_token (`str`, *optional*, defaults to `"<image>"`):
|
||||
Special token used to denote image location.
|
||||
image_seq_length (`int`, *optional*, defaults to 256):
|
||||
The number of image tokens to be used for each image in the input.
|
||||
"""
|
||||
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
image_processor_class = "AutoImageProcessor"
|
||||
tokenizer_class = "AutoTokenizer"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_processor=None,
|
||||
tokenizer=None,
|
||||
chat_template=None,
|
||||
image_token="<image>",
|
||||
image_seq_length=256,
|
||||
**kwargs,
|
||||
):
|
||||
self.image_seq_length = image_seq_length
|
||||
self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
|
||||
self.image_token_id = (
|
||||
tokenizer.image_token_id
|
||||
if getattr(tokenizer, "image_token_id", None)
|
||||
else tokenizer.convert_tokens_to_ids(self.image_token)
|
||||
)
|
||||
super().__init__(image_processor, tokenizer, chat_template=chat_template, **kwargs)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
images: ImageInput = None,
|
||||
text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
|
||||
**kwargs: Unpack[Ovis2ProcessorKwargs],
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
|
||||
and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
|
||||
the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to
|
||||
Ovis2ImageProcessor's [`~Ovis2ImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
|
||||
of the above two methods for more information.
|
||||
|
||||
Args:
|
||||
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
||||
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
||||
tensor. Both channels-first and channels-last formats are supported.
|
||||
text (`str`, `List[str]`, `List[List[str]]`):
|
||||
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
||||
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
||||
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
||||
|
||||
Returns:
|
||||
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
|
||||
|
||||
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
|
||||
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
|
||||
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
|
||||
`None`).
|
||||
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
|
||||
- **image_sizes** -- Size of each image that will be used to unpad an image. Returned when `images` is not `None`.
|
||||
"""
|
||||
|
||||
output_kwargs = self._merge_kwargs(
|
||||
Ovis2ProcessorKwargs,
|
||||
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if isinstance(text, str):
|
||||
text = [text]
|
||||
elif not isinstance(text, list) and not isinstance(text[0], str):
|
||||
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
|
||||
|
||||
image_inputs = {}
|
||||
|
||||
if images is not None:
|
||||
image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
|
||||
image_grids = image_inputs.pop("grids").tolist()
|
||||
text = self._expand_image_tokens(text, image_grids)
|
||||
|
||||
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
|
||||
return BatchFeature(data={**text_inputs, **image_inputs})
|
||||
|
||||
def _expand_image_tokens(
|
||||
self,
|
||||
text: list[TextInput],
|
||||
grids: list[list[int]],
|
||||
):
|
||||
processed_text = []
|
||||
grid_index = 0
|
||||
for sample in text:
|
||||
while "<image>" in sample:
|
||||
grid = grids[grid_index]
|
||||
row, col = grid[0], grid[1]
|
||||
placeholder = f"<IMG_START>{'<IMG_ATOM>' * self.image_seq_length}<IMG_GRID>"
|
||||
if row * col > 1:
|
||||
for r in range(row):
|
||||
for c in range(col):
|
||||
placeholder += f"{'<IMG_ATOM>' * self.image_seq_length}"
|
||||
if c < col - 1:
|
||||
placeholder += "<IMG_COL>"
|
||||
if r < row - 1:
|
||||
placeholder += "<IMG_ROW>"
|
||||
placeholder += "<IMG_END>"
|
||||
|
||||
sample = sample.replace("<image>", placeholder, 1)
|
||||
grid_index += 1
|
||||
processed_text.append(sample)
|
||||
return processed_text
|
||||
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
||||
refer to the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.batch_decode(*args, **kwargs)
|
||||
|
||||
def decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
|
||||
the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.decode(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def model_input_names(self):
|
||||
tokenizer_input_names = self.tokenizer.model_input_names
|
||||
image_processor_input_names = self.image_processor.model_input_names
|
||||
return list(tokenizer_input_names) + list(image_processor_input_names)
|
||||
|
||||
|
||||
__all__ = ["Ovis2Processor"]
|
||||
0
tests/models/ovis2/__init__.py
Normal file
0
tests/models/ovis2/__init__.py
Normal file
177
tests/models/ovis2/test_image_processing_ovis2.py
Normal file
177
tests/models/ovis2/test_image_processing_ovis2.py
Normal file
@@ -0,0 +1,177 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers.image_utils import SizeDict
|
||||
from transformers.testing_utils import require_torch, require_vision
|
||||
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
|
||||
|
||||
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_vision_available():
|
||||
from transformers import Ovis2ImageProcessor
|
||||
|
||||
if is_torchvision_available():
|
||||
from transformers import Ovis2ImageProcessorFast
|
||||
|
||||
|
||||
class Ovis2ImageProcessingTester(unittest.TestCase):
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=7,
|
||||
num_channels=3,
|
||||
image_size=18,
|
||||
min_resolution=30,
|
||||
max_resolution=400,
|
||||
do_resize=True,
|
||||
size=None,
|
||||
do_normalize=True,
|
||||
do_pad=False,
|
||||
image_mean=[0.48145466, 0.4578275, 0.40821073],
|
||||
image_std=[0.26862954, 0.26130258, 0.27577711],
|
||||
do_convert_rgb=True,
|
||||
):
|
||||
super().__init__()
|
||||
size = size if size is not None else {"height": 20, "width": 20}
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.num_channels = num_channels
|
||||
self.image_size = image_size
|
||||
self.min_resolution = min_resolution
|
||||
self.max_resolution = max_resolution
|
||||
self.do_resize = do_resize
|
||||
self.size = size
|
||||
self.do_normalize = do_normalize
|
||||
self.image_mean = image_mean
|
||||
self.image_std = image_std
|
||||
self.do_pad = do_pad
|
||||
self.do_convert_rgb = do_convert_rgb
|
||||
|
||||
def prepare_image_processor_dict(self):
|
||||
return {
|
||||
"do_resize": self.do_resize,
|
||||
"size": self.size,
|
||||
"do_normalize": self.do_normalize,
|
||||
"image_mean": self.image_mean,
|
||||
"image_std": self.image_std,
|
||||
"do_convert_rgb": self.do_convert_rgb,
|
||||
"do_pad": self.do_pad,
|
||||
}
|
||||
|
||||
def expected_output_image_shape(self, images):
|
||||
return self.num_channels, self.size["height"], self.size["width"]
|
||||
|
||||
def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
|
||||
return prepare_image_inputs(
|
||||
batch_size=self.batch_size,
|
||||
num_channels=self.num_channels,
|
||||
min_resolution=self.min_resolution,
|
||||
max_resolution=self.max_resolution,
|
||||
equal_resolution=equal_resolution,
|
||||
numpify=numpify,
|
||||
torchify=torchify,
|
||||
)
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
class Ovis2ProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
image_processing_class = Ovis2ImageProcessor if is_vision_available() else None
|
||||
fast_image_processing_class = Ovis2ImageProcessorFast if is_torchvision_available() else None
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.image_processor_tester = Ovis2ImageProcessingTester(self)
|
||||
|
||||
@property
|
||||
def image_processor_dict(self):
|
||||
return self.image_processor_tester.prepare_image_processor_dict()
|
||||
|
||||
def test_image_processor_properties(self):
|
||||
for image_processing_class in self.image_processor_list:
|
||||
image_processor = image_processing_class(**self.image_processor_dict)
|
||||
self.assertTrue(hasattr(image_processor, "do_resize"))
|
||||
self.assertTrue(hasattr(image_processor, "size"))
|
||||
self.assertTrue(hasattr(image_processor, "do_normalize"))
|
||||
self.assertTrue(hasattr(image_processor, "image_mean"))
|
||||
self.assertTrue(hasattr(image_processor, "image_std"))
|
||||
self.assertTrue(hasattr(image_processor, "do_convert_rgb"))
|
||||
|
||||
def test_slow_fast_equivalence_crop_to_patches(self):
|
||||
dummy_image = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)[0]
|
||||
|
||||
image_processor_slow = self.image_processing_class(**self.image_processor_dict, crop_to_patches=True)
|
||||
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict, crop_to_patches=True)
|
||||
|
||||
encoding_slow = image_processor_slow(dummy_image, return_tensors="pt")
|
||||
encoding_fast = image_processor_fast(dummy_image, return_tensors="pt")
|
||||
|
||||
# torch.testing.assert_close(encoding_slow.num_patches, encoding_fast.num_patches)
|
||||
self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1))
|
||||
self.assertLessEqual(
|
||||
torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3
|
||||
)
|
||||
|
||||
def test_slow_fast_equivalence_batched_crop_to_patches(self):
|
||||
# Prepare image inputs so that we have two groups of images with equal resolution with a group of images with
|
||||
# different resolutions in between
|
||||
dummy_images = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True)
|
||||
dummy_images += self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
|
||||
dummy_images += self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True)
|
||||
|
||||
image_processor_slow = self.image_processing_class(**self.image_processor_dict, crop_to_patches=True)
|
||||
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict, crop_to_patches=True)
|
||||
|
||||
encoding_slow = image_processor_slow(dummy_images, return_tensors="pt")
|
||||
encoding_fast = image_processor_fast(dummy_images, return_tensors="pt")
|
||||
|
||||
# torch.testing.assert_close(encoding_slow.num_patches, encoding_fast.num_patches)
|
||||
self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1))
|
||||
self.assertLessEqual(
|
||||
torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3
|
||||
)
|
||||
|
||||
def test_crop_to_patches(self):
|
||||
# test slow image processor
|
||||
image_processor = self.image_processor_list[0](**self.image_processor_dict)
|
||||
image = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, numpify=True)[0]
|
||||
processed_images, grid = image_processor.crop_image_to_patches(
|
||||
image,
|
||||
min_patches=1,
|
||||
max_patches=6,
|
||||
patch_size={"height": 20, "width": 20},
|
||||
)
|
||||
self.assertEqual(len(processed_images), 5)
|
||||
self.assertEqual(processed_images[0].shape[:2], (20, 20))
|
||||
self.assertEqual(len(grid), 2) # (row, col)
|
||||
|
||||
# test fast image processor (process batch)
|
||||
image_processor = self.image_processor_list[1](**self.image_processor_dict)
|
||||
image = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True)[0]
|
||||
processed_images, grid = image_processor.crop_image_to_patches(
|
||||
image.unsqueeze(0),
|
||||
min_patches=1,
|
||||
max_patches=6,
|
||||
patch_size=SizeDict(height=20, width=20),
|
||||
)
|
||||
self.assertEqual(len(processed_images[0]), 5)
|
||||
self.assertEqual(processed_images.shape[-2:], (20, 20))
|
||||
self.assertEqual(len(grid[0]), 2)
|
||||
384
tests/models/ovis2/test_modeling_ovis2.py
Normal file
384
tests/models/ovis2/test_modeling_ovis2.py
Normal file
@@ -0,0 +1,384 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import requests
|
||||
|
||||
from transformers import (
|
||||
AutoProcessor,
|
||||
Ovis2Config,
|
||||
Ovis2ForConditionalGeneration,
|
||||
Ovis2Model,
|
||||
is_torch_available,
|
||||
is_vision_available,
|
||||
)
|
||||
from transformers.testing_utils import (
|
||||
cleanup,
|
||||
require_torch,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import (
|
||||
ModelTesterMixin,
|
||||
floats_tensor,
|
||||
ids_tensor,
|
||||
)
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class Ovis2VisionText2TextModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
seq_length=7,
|
||||
text_config={
|
||||
"model_type": "qwen2",
|
||||
"seq_length": 7,
|
||||
"is_training": True,
|
||||
"use_labels": True,
|
||||
"vocab_size": 99,
|
||||
"hidden_size": 64,
|
||||
"num_hidden_layers": 2,
|
||||
"num_attention_heads": 4,
|
||||
"num_key_value_heads": 4,
|
||||
"intermediate_size": 54,
|
||||
"hidden_act": "gelu",
|
||||
"max_position_embeddings": 580,
|
||||
"initializer_range": 0.02,
|
||||
"num_labels": 3,
|
||||
"pad_token_id": 0,
|
||||
},
|
||||
is_training=True,
|
||||
vision_config={
|
||||
"image_size": 32,
|
||||
"patch_size": 8,
|
||||
"num_channels": 3,
|
||||
"hidden_size": 64,
|
||||
"vocab_size": 99,
|
||||
"num_hidden_layers": 2,
|
||||
"num_attention_heads": 4,
|
||||
"intermediate_size": 54,
|
||||
"attention_dropout": 0.0,
|
||||
"hidden_act": "silu",
|
||||
"qkv_bias": False,
|
||||
"hidden_stride": 2,
|
||||
"tokenize_function": "softmax",
|
||||
},
|
||||
image_token_id=1,
|
||||
visual_indicator_token_ids=[],
|
||||
vocab_size=99,
|
||||
hidden_size=64,
|
||||
ignore_id=-100,
|
||||
):
|
||||
self.parent = parent
|
||||
self.text_config = text_config
|
||||
self.vision_config = vision_config
|
||||
self.image_token_id = image_token_id
|
||||
self.visual_indicator_token_ids = visual_indicator_token_ids
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.image_seq_length = (
|
||||
vision_config["image_size"] // (vision_config["patch_size"] * vision_config["hidden_stride"])
|
||||
) ** 2
|
||||
self.seq_length = seq_length + self.image_seq_length
|
||||
self.is_training = is_training
|
||||
self.num_attention_heads = text_config["num_attention_heads"]
|
||||
|
||||
self.num_hidden_layers = text_config["num_hidden_layers"]
|
||||
self.pad_token_id = text_config["pad_token_id"]
|
||||
self.ignore_id = ignore_id
|
||||
|
||||
self.batch_size = 3
|
||||
self.num_channels = 3
|
||||
|
||||
def get_config(self):
|
||||
return Ovis2Config(
|
||||
text_config=self.text_config,
|
||||
vision_config=self.vision_config,
|
||||
image_token_id=self.image_token_id,
|
||||
visual_indicator_token_ids=self.visual_indicator_token_ids,
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_size=self.hidden_size,
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
pixel_values = floats_tensor(
|
||||
[
|
||||
self.batch_size,
|
||||
self.vision_config["num_channels"],
|
||||
self.vision_config["image_size"],
|
||||
self.vision_config["image_size"],
|
||||
]
|
||||
)
|
||||
config = self.get_config()
|
||||
|
||||
return config, pixel_values
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, pixel_values = config_and_inputs
|
||||
|
||||
vocab_range = self.vocab_size - 2
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], vocab_range) + 2
|
||||
input_ids[:, : self.image_seq_length] = config.image_token_id
|
||||
|
||||
attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(torch_device)
|
||||
|
||||
labels = torch.zeros((self.batch_size, self.seq_length), dtype=torch.long, device=torch_device)
|
||||
labels[:, : self.image_seq_length] = self.ignore_id
|
||||
|
||||
inputs_dict = {
|
||||
"pixel_values": pixel_values,
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"labels": labels,
|
||||
}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_torch
|
||||
class Ovis2ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
"""
|
||||
Model tester for `Ovis2ForConditionalGeneration`.
|
||||
"""
|
||||
|
||||
all_model_classes = (
|
||||
(
|
||||
Ovis2Model,
|
||||
Ovis2ForConditionalGeneration,
|
||||
)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
pipeline_model_mapping = {"image-text-to-text": Ovis2ForConditionalGeneration} if is_torch_available() else {}
|
||||
_is_composite = True
|
||||
test_pruning = False
|
||||
test_torchscript = False
|
||||
test_head_masking = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = Ovis2VisionText2TextModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=Ovis2Config, has_text_modality=False)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_inputs_embeds(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
|
||||
wte = model.get_input_embeddings()
|
||||
inputs["inputs_embeds"] = wte(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
model(**inputs)
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
# while some other models require pixel_values to be present
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
|
||||
inputs_embeds = model.get_input_embeddings()(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
out_ids = model(input_ids=input_ids, **inputs)[0]
|
||||
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
|
||||
torch.testing.assert_close(out_embeds, out_ids)
|
||||
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
class Ovis2IntegrationTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.processor = AutoProcessor.from_pretrained(
|
||||
"thisisiron/Ovis2-2B-hf",
|
||||
)
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
self.image = Image.open(requests.get(url, stream=True).raw)
|
||||
self.prompt_image = ""
|
||||
self.messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": "What do you see in this image?"},
|
||||
],
|
||||
}
|
||||
]
|
||||
self.text = self.processor.apply_chat_template(self.messages, add_generation_prompt=True, tokenize=False)
|
||||
|
||||
def tearDown(self):
|
||||
cleanup(torch_device, gc_collect=True)
|
||||
|
||||
def test_small_model_integration_test(self):
|
||||
model = Ovis2ForConditionalGeneration.from_pretrained(
|
||||
"thisisiron/Ovis2-2B-hf", torch_dtype="bfloat16", device_map=torch_device
|
||||
)
|
||||
|
||||
inputs = self.processor(images=self.image, text=self.text, return_tensors="pt").to(
|
||||
torch_device, torch.bfloat16
|
||||
)
|
||||
|
||||
self.assertTrue(inputs.input_ids.shape[1] == 1314) # should expand num-image-tokens times
|
||||
self.assertTrue(inputs.pixel_values.shape == torch.Size([5, 3, 448, 448]))
|
||||
|
||||
inputs = inputs.to(torch_device)
|
||||
|
||||
output = model.generate(**inputs, max_new_tokens=64)
|
||||
EXPECTED_DECODED_TEXT = 'system\nYou are a helpful assistant.\nuser\n\nWhat do you see in this image?\nassistant\nI see two cats lying on a pink blanket. There are also two remote controls on the blanket.' # fmt: skip
|
||||
self.assertEqual(
|
||||
self.processor.decode(output[0], skip_special_tokens=True),
|
||||
EXPECTED_DECODED_TEXT,
|
||||
)
|
||||
|
||||
def test_small_model_integration_test_batch(self):
|
||||
model = Ovis2ForConditionalGeneration.from_pretrained(
|
||||
"thisisiron/Ovis2-2B-hf", torch_dtype="bfloat16", device_map=torch_device
|
||||
)
|
||||
|
||||
inputs = self.processor(
|
||||
text=[self.text],
|
||||
images=self.image,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
).to(torch_device, torch.bfloat16)
|
||||
|
||||
output = model.generate(**inputs, max_new_tokens=20)
|
||||
|
||||
EXPECTED_DECODED_TEXT = ['system\nYou are a helpful assistant.\nuser\n\nWhat do you see in this image?\nassistant\nI see two cats lying on a pink blanket. There are also two remote controls on the blanket.'] # fmt: skip
|
||||
self.assertEqual(
|
||||
self.processor.batch_decode(output, skip_special_tokens=True),
|
||||
EXPECTED_DECODED_TEXT,
|
||||
)
|
||||
|
||||
def test_small_model_integration_test_multi_image(self):
|
||||
# related to (#29835)
|
||||
model = Ovis2ForConditionalGeneration.from_pretrained(
|
||||
"thisisiron/Ovis2-2B-hf",
|
||||
torch_dtype="bfloat16",
|
||||
device_map=torch_device,
|
||||
)
|
||||
|
||||
url = "http://images.cocodataset.org/val2014/COCO_val2014_000000537955.jpg"
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
prompt = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": "What do you see in these images?"},
|
||||
],
|
||||
}
|
||||
]
|
||||
text = self.processor.apply_chat_template(prompt, add_generation_prompt=True, tokenize=False)
|
||||
inputs = self.processor(text=text, images=[self.image, image], return_tensors="pt").to(
|
||||
torch_device, torch.bfloat16
|
||||
)
|
||||
|
||||
output = model.generate(**inputs, max_new_tokens=40)
|
||||
EXPECTED_DECODED_TEXT = 'system\nYou are a helpful assistant.\nuser\n\n\nWhat do you see in these images?\nassistant\nIn the first image, I see two cats lying on a pink blanket with remote controls nearby. The second image shows a dog standing on a wooden floor near a kitchen cabinet.' # fmt: skip
|
||||
|
||||
self.assertEqual(
|
||||
self.processor.decode(output[0], skip_special_tokens=True),
|
||||
EXPECTED_DECODED_TEXT,
|
||||
)
|
||||
|
||||
def test_small_model_integration_test_batch_different_resolutions(self):
|
||||
model = Ovis2ForConditionalGeneration.from_pretrained(
|
||||
"thisisiron/Ovis2-2B-hf", torch_dtype="bfloat16", device_map=torch_device
|
||||
)
|
||||
|
||||
lowres_url = "http://images.cocodataset.org/val2014/COCO_val2014_000000537955.jpg"
|
||||
lowres_img = Image.open(requests.get(lowres_url, stream=True).raw).resize((320, 240))
|
||||
|
||||
inputs = self.processor(
|
||||
text=[self.text, self.text],
|
||||
images=[lowres_img, self.image],
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
).to(torch_device, torch.bfloat16)
|
||||
|
||||
output = model.generate(**inputs, max_new_tokens=20)
|
||||
|
||||
EXPECTED_DECODED_TEXT = [
|
||||
'system\nYou are a helpful assistant.\nuser\n\nWhat do you see in this image?\nassistant\nAnswer: I see a brown dog standing on a wooden floor in what appears to be a kitchen.',
|
||||
'system\nYou are a helpful assistant.\nuser\n\nWhat do you see in this image?\nassistant\nI see two cats lying on a pink blanket. There are also two remote controls on the blanket.'
|
||||
] # fmt: skip
|
||||
|
||||
self.assertEqual(
|
||||
self.processor.batch_decode(output, skip_special_tokens=True),
|
||||
EXPECTED_DECODED_TEXT,
|
||||
)
|
||||
|
||||
def test_small_model_integration_test_batch_matches_single(self):
|
||||
model = Ovis2ForConditionalGeneration.from_pretrained(
|
||||
"thisisiron/Ovis2-2B-hf",
|
||||
torch_dtype="bfloat16",
|
||||
device_map=torch_device,
|
||||
)
|
||||
|
||||
lowres_url = "https://4.img-dpreview.com/files/p/TS560x560~forums/56876524/03975b28741443319e9a94615e35667e"
|
||||
lowres_img = Image.open(requests.get(lowres_url, stream=True).raw)
|
||||
|
||||
inputs_batched = self.processor(
|
||||
text=[self.text, self.text],
|
||||
images=[self.image, lowres_img],
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
).to(torch_device, torch.bfloat16)
|
||||
|
||||
inputs_single = self.processor(text=self.text, images=self.image, return_tensors="pt", padding=True).to(
|
||||
torch_device, torch.bfloat16
|
||||
)
|
||||
|
||||
output_batched = model.generate(**inputs_batched, max_new_tokens=50)
|
||||
output_single = model.generate(**inputs_single, max_new_tokens=50)
|
||||
|
||||
self.assertEqual(
|
||||
self.processor.decode(output_batched[0], skip_special_tokens=True),
|
||||
self.processor.decode(output_single[0], skip_special_tokens=True),
|
||||
)
|
||||
118
tests/models/ovis2/test_processor_ovis2.py
Normal file
118
tests/models/ovis2/test_processor_ovis2.py
Normal file
@@ -0,0 +1,118 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers.testing_utils import require_av, require_vision
|
||||
from transformers.utils import is_vision_available
|
||||
|
||||
from ...test_processing_common import ProcessorTesterMixin
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from transformers import (
|
||||
AutoProcessor,
|
||||
Ovis2ImageProcessor,
|
||||
Ovis2Processor,
|
||||
Qwen2TokenizerFast,
|
||||
)
|
||||
|
||||
|
||||
@require_vision
|
||||
class Ovis2ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
processor_class = Ovis2Processor
|
||||
|
||||
def setUp(self):
|
||||
self.tmpdirname = tempfile.mkdtemp()
|
||||
image_processor = Ovis2ImageProcessor()
|
||||
tokenizer = Qwen2TokenizerFast.from_pretrained("thisisiron/Ovis2-1B-hf")
|
||||
processor_kwargs = self.prepare_processor_dict()
|
||||
|
||||
processor = Ovis2Processor(image_processor=image_processor, tokenizer=tokenizer, **processor_kwargs)
|
||||
processor.save_pretrained(self.tmpdirname)
|
||||
|
||||
def get_tokenizer(self, **kwargs):
|
||||
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer
|
||||
|
||||
def get_image_processor(self, **kwargs):
|
||||
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
|
||||
|
||||
def prepare_processor_dict(self):
|
||||
return {
|
||||
"chat_template": "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n'}}{% if message['content'] is string %}{{ message['content'] }}{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' %}{{ '<image>\n' }}{% elif content['type'] == 'text' %}{{ content['text'] }}{% endif %}{% endfor %}{% endif %}{{'<|im_end|>\n'}}{% endfor %}{% if add_generation_prompt %}{{'<|im_start|>assistant\n' }}{% endif %}",
|
||||
} # fmt: skip
|
||||
|
||||
def test_processor_to_json_string(self):
|
||||
processor = self.get_processor()
|
||||
obj = json.loads(processor.to_json_string())
|
||||
for key, value in self.prepare_processor_dict().items():
|
||||
# chat_tempalate are tested as a separate test because they are saved in separate files
|
||||
if key != "chat_template":
|
||||
self.assertEqual(obj[key], value)
|
||||
self.assertEqual(getattr(processor, key, None), value)
|
||||
|
||||
def test_chat_template_is_saved(self):
|
||||
processor_loaded = self.processor_class.from_pretrained(self.tmpdirname)
|
||||
processor_dict_loaded = json.loads(processor_loaded.to_json_string())
|
||||
# chat templates aren't serialized to json in processors
|
||||
self.assertFalse("chat_template" in processor_dict_loaded)
|
||||
|
||||
# they have to be saved as separate file and loaded back from that file
|
||||
# so we check if the same template is loaded
|
||||
processor_dict = self.prepare_processor_dict()
|
||||
self.assertTrue(processor_loaded.chat_template == processor_dict.get("chat_template", None))
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.tmpdirname)
|
||||
|
||||
def test_chat_template(self):
|
||||
processor = AutoProcessor.from_pretrained("thisisiron/Ovis2-1B-hf")
|
||||
expected_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<image>\nWhat is shown in this image?<|im_end|>\n<|im_start|>assistant\n"
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": "What is shown in this image?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
|
||||
self.assertEqual(expected_prompt, formatted_prompt)
|
||||
|
||||
@require_av
|
||||
def test_chat_template_dict(self):
|
||||
processor = AutoProcessor.from_pretrained("thisisiron/Ovis2-1B-hf")
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": "What is shown in this image?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
formatted_prompt_tokenized = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
|
||||
expected_output = [[151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 27, 1805, 397, 3838, 374, 6839, 304, 419, 2168, 30, 151645, 198, 151644, 77091, 198]] # fmt: skip
|
||||
self.assertListEqual(expected_output, formatted_prompt_tokenized)
|
||||
|
||||
out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
|
||||
self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])
|
||||
@@ -94,6 +94,7 @@ PRIVATE_MODELS = [
|
||||
"Glm4vVisionModel",
|
||||
"Glm4vMoeVisionModel",
|
||||
"EvollaSaProtPreTrainedModel",
|
||||
"Ovis2VisionModel",
|
||||
]
|
||||
|
||||
# Update this list for models that are not tested with a comment explaining the reason it should not be.
|
||||
|
||||
Reference in New Issue
Block a user