Add EoMT Model || 🚨 Fix Mask2Former loss calculation (#37610)

* Initial Commit

* up

* More changes

* up

* Only mask_logits mismatch

* close enough logits debug later

* fixes

* format

* Add dummy loss

* Close enough processing for semantic seg

* nit

* Added panoptic postprocessor

* refactor

* refactor

* finally fixed panoptic postprocessor

* temp update

* Refactor ForUniversalSegmentation class

* nits and config update

* Few fixes and inference matches

* change mapping

* Added training support but loss slightly off 🥲

* Loss is matching 😀

* update

* Initial tests skeleton

* changes

* tests update

* more modular

* initial tests

* updates

* better docstrings

* changes

* proc tests passing :)

* Image processor update

* tiny change

* QOL changes

* Update test w.r.t latest attn refactor

* repo-consistency fixes

* up

* Image proc fix and integration tests :)

* docs update

* integration tests

* fix

* docs update 🥰

* minor fix

* Happy CI

* fix

* obvious refactoring

* refactoring w.r.t review

* Add fast image proc skeleton

* Fast Image proc and cleanups

* Use more modular

* tests update

* Add more tests

* Nit

* QOL updates

* change init_weights to torch default

* add eager func coz of make style

* up

* changes

* typo fix

* Updates

* More deterministic tests

* More modular

* go more modular 🚀

* up

* dump

* add support for giant ckpts

* overhaul

* modular

* refactor

* instance seg is ready

* cleanup

* forgot this

* docs cleanup

* minor changes

* EoMT - > Eomt

* Happy CI

* remove redundant comment

* Change model references

* final change

* check annealing per block

* My other PR changes 😂

---------

Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>
Yaswanth Gali
2025-06-27 17:48:18 +05:30
committed by GitHub
parent 0106a50a6b
commit 1750c518dd
16 changed files with 4923 additions and 1 deletions


@@ -737,6 +737,8 @@
title: EfficientFormer
- local: model_doc/efficientnet
title: EfficientNet
- local: model_doc/eomt
title: EoMT
- local: model_doc/focalnet
title: FocalNet
- local: model_doc/glpn


@@ -0,0 +1,214 @@
<!--Copyright 2025 Mobile Perception Systems Lab at TU/e and The HuggingFace Inc. team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# EoMT
<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
</div>
## Overview
The Encoder-only Mask Transformer (EoMT) model was introduced in the CVPR 2025 Highlight Paper [Your ViT is Secretly an Image Segmentation Model](https://www.tue-mps.org/eomt) by Tommie Kerssies, Niccolò Cavagnero, Alexander Hermans, Narges Norouzi, Giuseppe Averta, Bastian Leibe, Gijs Dubbelman, and Daan de Geus.
EoMT reveals Vision Transformers can perform image segmentation efficiently without task-specific components.
The abstract from the paper is the following:
*Vision Transformers (ViTs) have shown remarkable performance and scalability across various computer vision tasks. To apply single-scale ViTs to image segmentation, existing methods adopt a convolutional adapter to generate multi-scale features, a pixel decoder to fuse these features, and a Transformer decoder that uses the fused features to make predictions. In this paper, we show that the inductive biases introduced by these task-specific components can instead be learned by the ViT itself, given sufficiently large models and extensive pre-training. Based on these findings, we introduce the Encoder-only Mask Transformer (EoMT), which repurposes the plain ViT architecture to conduct image segmentation. With large-scale models and pre-training, EoMT obtains a segmentation accuracy similar to state-of-the-art models that use task-specific components. At the same time, EoMT is significantly faster than these methods due to its architectural simplicity, e.g., up to 4x faster with ViT-L. Across a range of model sizes, EoMT demonstrates an optimal balance between segmentation accuracy and prediction speed, suggesting that compute resources are better spent on scaling the ViT itself rather than adding architectural complexity.*
This model was contributed by [Yaswanth Gali](https://huggingface.co/yaswanthgali).
The original code can be found [here](https://github.com/tue-mps/eomt).
## Architecture Info
The `EoMT` model uses a DINOv2-pretrained Vision Transformer with **register tokens** as its backbone. EoMT simplifies the segmentation pipeline by relying solely on the encoder, eliminating the need for task-specific decoders commonly used in prior approaches.
Architecturally, EoMT introduces a small set of **learned queries** and a lightweight **mask prediction module**. These queries are injected into the final encoder blocks, enabling **joint attention** between image patches and object queries. During training, **masked attention** is applied to constrain each query to focus on its corresponding region—effectively mimicking cross-attention. This constraint is gradually phased out via a **mask annealing strategy**, allowing for **efficient, decoder-free inference** without compromising segmentation performance.
<div style="text-align: center;">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/eomt_architecture.png"
alt="drawing" width="500"/>
</div>
The model supports semantic, instance, and panoptic segmentation using a unified architecture and task-specific post-processing.
## Usage Examples
Use the Hugging Face implementation of EoMT for inference with pre-trained models.
### Semantic Segmentation
The EoMT model performs semantic segmentation using sliding-window inference. The input image is resized such that the shorter side matches the target input size, then it is split into overlapping crops. Each crop is then passed through the model. After inference, the predicted logits from each crop are stitched back together and rescaled to the original image size to get the final segmentation mask.
> **Note:**
> If you want to use a custom target size for **semantic segmentation**, specify it in the following format:
> `{"shortest_edge": 512}`
> Note that `longest_edge` is intentionally omitted here. For semantic segmentation, images are typically **scaled so that the shortest edge is greater than or equal to the target size**, so `longest_edge` is not needed.
```python
import matplotlib.pyplot as plt
import requests
import torch
from PIL import Image
from transformers import EomtForUniversalSegmentation, AutoImageProcessor
model_id = "tue-mps/ade20k_semantic_eomt_large_512"
processor = AutoImageProcessor.from_pretrained(model_id)
model = EomtForUniversalSegmentation.from_pretrained(model_id)
image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
inputs = processor(
    images=image,
    return_tensors="pt",
)

# Remove patch offsets from the inputs; they are only used later for post-processing.
patch_offsets = inputs.pop("patch_offsets")

with torch.inference_mode():
    outputs = model(**inputs)

# Prepare the original image size in the format (height, width)
original_image_sizes = [(image.height, image.width)]

# Post-process the model outputs to get final segmentation prediction
preds = processor.post_process_semantic_segmentation(
    outputs,
    patch_offsets=patch_offsets,
    original_image_sizes=original_image_sizes,
)
# Visualize the segmentation mask
plt.imshow(preds[0])
plt.axis("off")
plt.title("Semantic Segmentation")
plt.show()
```
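The sliding-window step described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the image processor's actual code: the helper name `crop_offsets` and the even-spacing rule are hypothetical and chosen only to show how overlapping crops can cover the longer side.

```python
import math


def crop_offsets(longer_side: int, crop_size: int) -> list[tuple[int, int]]:
    """Return (start, end) windows of width crop_size that cover longer_side.

    Hypothetical helper: windows are spread evenly, so neighbouring crops
    overlap whenever longer_side is not a multiple of crop_size.
    """
    if longer_side < crop_size:
        raise ValueError("longer_side must be at least crop_size")
    num_crops = math.ceil(longer_side / crop_size)
    if num_crops == 1:
        return [(0, crop_size)]
    # The last window must end exactly at longer_side, so the starts are
    # spaced evenly between 0 and (longer_side - crop_size).
    span = longer_side - crop_size
    starts = [i * span // (num_crops - 1) for i in range(num_crops)]
    return [(start, start + crop_size) for start in starts]


# Shorter side already resized to 512, longer side is 683: two crops.
windows = crop_offsets(longer_side=683, crop_size=512)
print(windows)  # [(0, 512), (171, 683)]
```

The returned start positions play the role of the patch offsets: they say where each crop's logits belong when the crops are stitched back together.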
### Instance Segmentation
The EoMT model performs instance segmentation using padded inference. The input image is resized so that the longer side matches the target input size, and the shorter side is zero-padded to form a square. The resulting mask and class logits are combined through post-processing (adapted from Mask2Former) to produce a unified instance segmentation map, along with segment metadata like segment id, class labels and confidence scores.
> **Note:**
> To use a custom target size, specify the size as a dictionary in the following format:
> `{"shortest_edge": 512, "longest_edge": 512}`
> For both instance and panoptic segmentation, input images will be **scaled and padded** to this target size.
```python
import matplotlib.pyplot as plt
import requests
import torch
from PIL import Image
from transformers import EomtForUniversalSegmentation, AutoImageProcessor
model_id = "tue-mps/coco_instance_eomt_large_640"
processor = AutoImageProcessor.from_pretrained(model_id)
model = EomtForUniversalSegmentation.from_pretrained(model_id)
image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
inputs = processor(
    images=image,
    return_tensors="pt",
)

with torch.inference_mode():
    outputs = model(**inputs)

# Prepare the original image size in the format (height, width)
original_image_sizes = [(image.height, image.width)]

# Post-process the model outputs to get final segmentation prediction
preds = processor.post_process_instance_segmentation(
    outputs,
    original_image_sizes=original_image_sizes,
)
# Visualize the segmentation mask
plt.imshow(preds[0]["segmentation"])
plt.axis("off")
plt.title("Instance Segmentation")
plt.show()
```
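The padded-inference geometry can be worked out with a few lines of arithmetic. The sketch below is illustrative only: `resize_and_pad_shape` is a hypothetical helper, and the rounding rule and bottom/right pad placement are assumptions rather than the processor's documented behaviour.

```python
def resize_and_pad_shape(height: int, width: int, target: int) -> dict:
    """Scale so the longer side equals target, then pad the shorter side.

    Hypothetical helper for illustration; rounding and pad placement
    (bottom/right) are assumptions.
    """
    scale = target / max(height, width)
    new_height = min(target, round(height * scale))
    new_width = min(target, round(width * scale))
    return {
        "resized": (new_height, new_width),
        "pad_bottom": target - new_height,
        "pad_right": target - new_width,
    }


# A 480 x 640 (height x width) image with a 640 target keeps its size
# and receives 160 rows of zero padding to become square.
shape = resize_and_pad_shape(height=480, width=640, target=640)
print(shape)  # {'resized': (480, 640), 'pad_bottom': 160, 'pad_right': 0}
```

Post-processing then has to undo both steps: crop away the padding and rescale the masks to the original `(height, width)`.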
### Panoptic Segmentation
The EoMT model performs panoptic segmentation using the same padded inference strategy as in instance segmentation. After padding and normalization, the model predicts both thing (instances) and stuff (amorphous regions) classes. The resulting mask and class logits are combined through post-processing (adapted from Mask2Former) to produce a unified panoptic segmentation map, along with segment metadata like segment id, class labels and confidence scores.
```python
import matplotlib.pyplot as plt
import requests
import torch
from PIL import Image
from transformers import EomtForUniversalSegmentation, AutoImageProcessor
model_id = "tue-mps/coco_panoptic_eomt_large_640"
processor = AutoImageProcessor.from_pretrained(model_id)
model = EomtForUniversalSegmentation.from_pretrained(model_id)
image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
inputs = processor(
    images=image,
    return_tensors="pt",
)

with torch.inference_mode():
    outputs = model(**inputs)

# Prepare the original image size in the format (height, width)
original_image_sizes = [(image.height, image.width)]

# Post-process the model outputs to get final segmentation prediction
preds = processor.post_process_panoptic_segmentation(
    outputs,
    original_image_sizes=original_image_sizes,
)
# Visualize the panoptic segmentation mask
plt.imshow(preds[0]["segmentation"])
plt.axis("off")
plt.title("Panoptic Segmentation")
plt.show()
```
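To make the "mask and class logits are combined" step more concrete, here is a deliberately simplified, pure-Python sketch of a Mask2Former-style merge. It is not the library's post-processing: `merge_queries`, the thresholds, and the output keys are assumptions for illustration, and the area and overlap filtering that a real implementation would apply is omitted.

```python
def merge_queries(class_probs, mask_probs, no_object_id, score_threshold=0.8):
    """Merge per-query predictions into one segment-id map.

    class_probs: one list of class probabilities per query (including the
        "no object" class). mask_probs: one H x W grid of mask probabilities
        per query. Simplified sketch: no area or overlap filtering.
    """
    kept = []  # (query index, label, score) for confident, real-object queries
    for query, probs in enumerate(class_probs):
        label = max(range(len(probs)), key=probs.__getitem__)
        if label != no_object_id and probs[label] > score_threshold:
            kept.append((query, label, probs[label]))

    height, width = len(mask_probs[0]), len(mask_probs[0][0])
    segmentation = [[-1] * width for _ in range(height)]  # -1 means unassigned
    segments_info = [
        {"id": seg_id, "label_id": label, "score": score}
        for seg_id, (_, label, score) in enumerate(kept)
    ]
    if not kept:
        return segmentation, segments_info

    for y in range(height):
        for x in range(width):
            # Each pixel goes to the query with the highest score-weighted mask.
            best = max(
                range(len(kept)),
                key=lambda i: kept[i][2] * mask_probs[kept[i][0]][y][x],
            )
            if mask_probs[kept[best][0]][y][x] >= 0.5:
                segmentation[y][x] = best
    return segmentation, segments_info


# Three queries on a 1 x 2 image; class index 2 is "no object".
class_probs = [[0.9, 0.05, 0.05], [0.1, 0.85, 0.05], [0.1, 0.1, 0.8]]
mask_probs = [[[0.9, 0.2]], [[0.1, 0.7]], [[0.5, 0.5]]]
segmentation, segments_info = merge_queries(class_probs, mask_probs, no_object_id=2)
print(segmentation)  # [[0, 1]]
print(segments_info)
```

The third query is dropped because its most likely class is "no object", which is why only two segments appear in the result.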
## EomtImageProcessor
[[autodoc]] EomtImageProcessor
- preprocess
- post_process_semantic_segmentation
- post_process_instance_segmentation
- post_process_panoptic_segmentation
## EomtImageProcessorFast
[[autodoc]] EomtImageProcessorFast
- preprocess
- post_process_semantic_segmentation
- post_process_instance_segmentation
- post_process_panoptic_segmentation
## EomtConfig
[[autodoc]] EomtConfig
## EomtForUniversalSegmentation
[[autodoc]] EomtForUniversalSegmentation
- forward


@@ -122,6 +122,7 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
("emu3", "Emu3Config"),
("encodec", "EncodecConfig"),
("encoder-decoder", "EncoderDecoderConfig"),
("eomt", "EomtConfig"),
("ernie", "ErnieConfig"),
("ernie_m", "ErnieMConfig"),
("esm", "EsmConfig"),
@@ -501,6 +502,7 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str](
("emu3", "Emu3"),
("encodec", "EnCodec"),
("encoder-decoder", "Encoder decoder"),
("eomt", "EoMT"),
("ernie", "ERNIE"),
("ernie_m", "ErnieM"),
("esm", "ESM"),


@@ -84,6 +84,7 @@ else:
("dpt", ("DPTImageProcessor", "DPTImageProcessorFast")),
("efficientformer", ("EfficientFormerImageProcessor",)),
("efficientnet", ("EfficientNetImageProcessor", "EfficientNetImageProcessorFast")),
("eomt", ("EomtImageProcessor", "EomtImageProcessorFast")),
("flava", ("FlavaImageProcessor", "FlavaImageProcessorFast")),
("focalnet", ("BitImageProcessor", "BitImageProcessorFast")),
("fuyu", ("FuyuImageProcessor",)),


@@ -854,6 +854,7 @@ MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES = OrderedDict(
[
# Model for Universal Segmentation mapping
("detr", "DetrForSegmentation"),
("eomt", "EomtForUniversalSegmentation"),
("mask2former", "Mask2FormerForUniversalSegmentation"),
("maskformer", "MaskFormerForInstanceSegmentation"),
("oneformer", "OneFormerForUniversalSegmentation"),


@@ -0,0 +1,29 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
    from .configuration_eomt import *
    from .image_processing_eomt import *
    from .image_processing_eomt_fast import *
    from .modeling_eomt import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)


@@ -0,0 +1,168 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/eomt/modular_eomt.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_eomt.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 Mobile Perception Systems Lab at TU/e and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...configuration_utils import PretrainedConfig
class EomtConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`EomtForUniversalSegmentation`]. It is used to instantiate an EoMT model
    according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the EoMT
    [tue-mps/coco_panoptic_eomt_large_640](https://huggingface.co/tue-mps/coco_panoptic_eomt_large_640)
    architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        hidden_size (`int`, *optional*, defaults to 1024):
            Dimensionality of the hidden representations.
        num_hidden_layers (`int`, *optional*, defaults to 24):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads in each attention layer.
        mlp_ratio (`int`, *optional*, defaults to 4):
            Ratio of the MLP hidden dimensionality to the hidden size.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
            The dropout probability for all fully connected layers in the embeddings and encoder.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the layer normalization layers.
        image_size (`int`, *optional*, defaults to 640):
            The size (resolution) of each input image.
        patch_size (`int`, *optional*, defaults to 16):
            The size (resolution) of each patch.
        num_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        layerscale_value (`float`, *optional*, defaults to 1.0):
            Initial value for the LayerScale parameter.
        drop_path_rate (`float`, *optional*, defaults to 0.0):
            The stochastic depth rate (drop path) used during training.
        num_upscale_blocks (`int`, *optional*, defaults to 2):
            Number of upsampling blocks used in the decoder or segmentation head.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            Dropout probability applied after attention projection.
        use_swiglu_ffn (`bool`, *optional*, defaults to `False`):
            Whether to use the SwiGLU feedforward neural network.
        num_blocks (`int`, *optional*, defaults to 4):
            Number of feature blocks or stages in the architecture.
        no_object_weight (`float`, *optional*, defaults to 0.1):
            Loss weight for the 'no object' class in panoptic/instance segmentation.
        class_weight (`float`, *optional*, defaults to 2.0):
            Loss weight for classification targets.
        mask_weight (`float`, *optional*, defaults to 5.0):
            Loss weight for mask prediction.
        dice_weight (`float`, *optional*, defaults to 5.0):
            Loss weight for the dice loss component.
        train_num_points (`int`, *optional*, defaults to 12544):
            Number of points to sample for mask loss computation during training.
        oversample_ratio (`float`, *optional*, defaults to 3.0):
            Oversampling ratio used in point sampling for mask training.
        importance_sample_ratio (`float`, *optional*, defaults to 0.75):
            Ratio of points to sample based on importance during training.
        num_queries (`int`, *optional*, defaults to 200):
            Number of object queries in the Transformer.
        num_register_tokens (`int`, *optional*, defaults to 4):
            Number of learnable register tokens added to the transformer input.

    Example:

    ```python
    >>> from transformers import EomtConfig, EomtForUniversalSegmentation

    >>> # Initialize configuration
    >>> config = EomtConfig()

    >>> # Initialize model
    >>> model = EomtForUniversalSegmentation(config)

    >>> # Access config
    >>> config = model.config
    ```"""

    model_type = "eomt"

    def __init__(
        self,
        hidden_size=1024,
        num_hidden_layers=24,
        num_attention_heads=16,
        mlp_ratio=4,
        hidden_act="gelu",
        hidden_dropout_prob=0.0,
        initializer_range=0.02,
        layer_norm_eps=1e-6,
        image_size=640,
        patch_size=16,
        num_channels=3,
        layerscale_value=1.0,
        drop_path_rate=0.0,
        num_upscale_blocks=2,
        attention_dropout=0.0,
        use_swiglu_ffn=False,
        num_blocks=4,
        no_object_weight: float = 0.1,
        class_weight: float = 2.0,
        mask_weight: float = 5.0,
        dice_weight: float = 5.0,
        train_num_points: int = 12544,
        oversample_ratio: float = 3.0,
        importance_sample_ratio: float = 0.75,
        num_queries=200,
        num_register_tokens=4,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.mlp_ratio = mlp_ratio
        self.attention_dropout = attention_dropout
        self.layerscale_value = layerscale_value
        self.drop_path_rate = drop_path_rate
        self.num_upscale_blocks = num_upscale_blocks
        self.use_swiglu_ffn = use_swiglu_ffn
        self.num_blocks = num_blocks
        self.no_object_weight = no_object_weight
        self.class_weight = class_weight
        self.mask_weight = mask_weight
        self.dice_weight = dice_weight
        self.train_num_points = train_num_points
        self.oversample_ratio = oversample_ratio
        self.importance_sample_ratio = importance_sample_ratio
        self.num_queries = num_queries
        self.num_register_tokens = num_register_tokens


__all__ = ["EomtConfig"]
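As a rough worked example of what the default configuration values imply, the sketch below computes token counts from `image_size`, `patch_size`, `num_register_tokens`, and `num_queries`. `TokenBudget` is a hypothetical stand-in (it does not import `transformers`), and the layout it assumes (one class token plus the register tokens in front of the patch tokens, with the queries added only in the last blocks) is an assumption for illustration, not something this configuration file states.

```python
from dataclasses import dataclass


@dataclass
class TokenBudget:
    """Hypothetical helper holding a subset of the defaults listed above."""

    image_size: int = 640
    patch_size: int = 16
    num_register_tokens: int = 4
    num_queries: int = 200

    def patch_tokens(self) -> int:
        # A square image is cut into (image_size / patch_size)^2 patches.
        side = self.image_size // self.patch_size
        return side * side

    def encoder_tokens(self) -> int:
        # Assumed layout: patches + one class token + register tokens.
        return self.patch_tokens() + 1 + self.num_register_tokens

    def tokens_with_queries(self) -> int:
        # Assumed: the learned queries join the sequence in the final blocks.
        return self.encoder_tokens() + self.num_queries


budget = TokenBudget()
print(budget.patch_tokens())         # 1600
print(budget.encoder_tokens())       # 1605
print(budget.tokens_with_queries())  # 1805
```

Under these assumptions, the queries add roughly 12% to the sequence length, and only in the blocks where they are present.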


@@ -0,0 +1,340 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import gc
import json
import os
import re
from typing import Optional
import torch
from accelerate import init_empty_weights
from huggingface_hub import snapshot_download
from transformers import EomtConfig, EomtForUniversalSegmentation, EomtImageProcessorFast
# fmt: off
MAPPINGS = {
# Embeddings
r"network.encoder.backbone.cls_token" : r"embeddings.cls_token",
r"network.encoder.backbone.reg_token" : r"embeddings.register_tokens",
r"network.encoder.backbone.pos_embed" : r"embeddings.position_embeddings.weight",
r"network.encoder.backbone.patch_embed.proj" : r"embeddings.patch_embeddings.projection",
# Encoder Block
r"network.encoder.backbone.blocks.(\d+).norm1" : r"layers.\1.norm1",
r"network.encoder.backbone.blocks.(\d+).attn.proj" : r"layers.\1.attention.out_proj",
r"network.encoder.backbone.blocks.(\d+).ls1.gamma" : r"layers.\1.layer_scale1.lambda1",
r"network.encoder.backbone.blocks.(\d+).norm2" : r"layers.\1.norm2",
r"network.encoder.backbone.blocks.(\d+).ls2.gamma" : r"layers.\1.layer_scale2.lambda1",
r"network.encoder.backbone.blocks.(\d+).attn" : r"layers.\1.attention",
# Others
r"network.q.weight" : r"query.weight",
r"network.class_head" : r"class_predictor",
r"network.upscale.(\d+).conv1" : r"upscale_block.block.\1.conv1",
r"network.upscale.(\d+).conv2" : r"upscale_block.block.\1.conv2",
r"network.upscale.(\d+).norm" : r"upscale_block.block.\1.layernorm2d",
r"network.mask_head.0" : r"mask_head.fc1",
r"network.mask_head.2" : r"mask_head.fc2",
r"network.mask_head.4" : r"mask_head.fc3",
r"network.encoder.backbone.norm" : r"layernorm",
r"network.attn_mask_probs" : r"attn_mask_probs",
}
# fmt: on
# Mappings for MLP layers, depending on the type of MLP used in ckpts.
MLP_MAPPINGS = {
"swiglu_ffn": {
r"network.encoder.backbone.blocks.(\d+).mlp.fc1": r"layers.\1.mlp.weights_in",
r"network.encoder.backbone.blocks.(\d+).mlp.fc2": r"layers.\1.mlp.weights_out",
},
"vanilla_mlp": {
r"network.encoder.backbone.blocks.(\d+).mlp": r"layers.\1.mlp",
},
}
def convert_old_keys_to_new_keys(state_dict):
    keys_as_text = "\n".join(state_dict.keys())
    new_keys_as_text = keys_as_text
    for old, repl in MAPPINGS.items():
        if repl is None:
            new_keys_as_text = re.sub(old, "", new_keys_as_text)
        else:
            new_keys_as_text = re.sub(old, repl, new_keys_as_text)
    output_dict = dict(zip(keys_as_text.split("\n"), new_keys_as_text.split("\n")))
    return output_dict
def split_qkv_tensor(key, tensor):
    """Splits a qkv tensor into separate q, k, v tensors and updates the key accordingly."""
    new_keys = ["q_proj", "k_proj", "v_proj"]
    split_size = tensor.shape[0] // 3
    split_tensors = torch.split(tensor, split_size, dim=0)
    return {key.replace("qkv", new_key): split_tensors[i] for i, new_key in enumerate(new_keys)}
def convert_state_dict_to_hf(state_dict):
    """Convert state dict keys to HF format."""
    conversion_dict = convert_old_keys_to_new_keys(state_dict)
    converted_state_dict = {}

    for old_key, new_key in conversion_dict.items():
        if new_key:
            if "qkv" in new_key:  # Detect merged attention keys and split them.
                qkv_split_dict = split_qkv_tensor(new_key, state_dict[old_key])
                converted_state_dict.update(qkv_split_dict)
            else:
                converted_state_dict[new_key] = state_dict[old_key]

    for i in [
        "network.encoder.pixel_mean",
        "network.encoder.pixel_std",
    ]:
        converted_state_dict.pop(i)

    # Embeddings will not have initial dimension
    pos_embed_key = "embeddings.position_embeddings.weight"
    converted_state_dict[pos_embed_key] = converted_state_dict[pos_embed_key].squeeze(0)

    return converted_state_dict
def ensure_model_downloaded(
    repo_id: Optional[str] = None, revision: Optional[str] = None, local_dir: Optional[str] = None
) -> str:
    """
    Ensures model files are downloaded locally, downloads them if not.
    Returns path to local files.

    Args:
        repo_id: The Hugging Face model repo ID (required if local_dir not provided)
        revision: Optional git revision to use
        local_dir: Optional local directory path where model files should be stored/found
    """
    if local_dir is not None:
        if os.path.exists(local_dir):
            print(f"Using provided local directory: {local_dir}")
        else:
            # Create the local directory if it doesn't exist
            os.makedirs(local_dir, exist_ok=True)
            print(f"Created local directory: {local_dir}")

    if repo_id is None:
        raise ValueError("Either repo_id or local_dir must be provided")

    print(f"Ensuring {repo_id} (revision: {revision or 'latest'}) is downloaded...")

    try:
        # First try to find files locally
        download_dir = snapshot_download(repo_id, revision=revision, local_files_only=True, local_dir=local_dir)
        print(f"Found model files locally at {download_dir}")
        return download_dir
    except Exception:
        # If files not found locally, download them
        print(f"Downloading model files for {repo_id}...")
        download_dir = snapshot_download(repo_id, revision=revision, local_files_only=False, local_dir=local_dir)
        print(f"Downloaded model files to {download_dir}")
        return download_dir
def load_model_state_dict(input_path: str) -> dict:
    """
    Load model state dict, handling both single and sharded files.
    """
    index_path = os.path.join(input_path, "pytorch_model.bin.index.json")
    single_file_path = os.path.join(input_path, "pytorch_model.bin")

    # Check if we have a sharded model
    if os.path.exists(index_path):
        print("Loading sharded model...")
        state_dict = {}
        with open(index_path, "r") as f:
            index = json.load(f)

        # Get unique shard files and load each one only once
        unique_shard_files = sorted(set(index["weight_map"].values()))
        for shard_file in unique_shard_files:
            print(f"Loading shard {shard_file}...")
            shard_path = os.path.join(input_path, shard_file)
            shard_dict = torch.load(shard_path, map_location="cpu")
            state_dict.update(shard_dict)

        return state_dict

    # Single file model
    elif os.path.exists(single_file_path):
        print("Loading single file model...")
        return torch.load(single_file_path, map_location="cpu")

    else:
        raise ValueError(f"No model files found in {input_path}")
def convert_model(
    repo_id=None,
    local_dir=None,
    output_dir=None,
    output_hub_path=None,
    safe_serialization=True,
    revision=None,
):
    """Convert and save the model weights, processor, and configuration."""
    if output_dir is None and output_hub_path is None:
        raise ValueError("At least one of output_dir or output_hub_path must be specified")

    if repo_id is None and local_dir is None:
        raise ValueError("Either repo_id or local_dir must be specified")

    # Create output directory if specified
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
        print(f"Created/verified output directory: {output_dir}")

    torch.set_default_dtype(torch.float16)

    # Download or locate model files
    input_path = ensure_model_downloaded(repo_id=repo_id, revision=revision, local_dir=local_dir)

    with open(os.path.join(input_path, "config.json"), "r") as f:
        config_data = json.load(f)
    # Pop off unwanted keys
    _ = config_data.pop("backbone", None)

    config = EomtConfig(
        **{
            **config_data,
            "layerscale_value": 1e-5,
        }
    )

    if "semantic" in repo_id.split("_"):
        size = {"shortest_edge": config.image_size, "longest_edge": None}
        do_split_image = True
        do_pad = False
    else:
        size = {"shortest_edge": config.image_size, "longest_edge": config.image_size}
        do_split_image = False
        do_pad = True

    if "giant" in repo_id.split("_"):
        config.use_swiglu_ffn = True
        config.hidden_size = 1536
        config.num_hidden_layers = 40
        config.num_attention_heads = 24
        # Update MAPPINGS for ckpts depending on the MLP type
        MAPPINGS.update(MLP_MAPPINGS["swiglu_ffn"])
    else:
        MAPPINGS.update(MLP_MAPPINGS["vanilla_mlp"])

    processor = EomtImageProcessorFast(size=size, do_split_image=do_split_image, do_pad=do_pad)

    # Save the config and processor
    if output_dir:
        config.save_pretrained(output_dir)
        processor.save_pretrained(output_dir)
    if output_hub_path:
        config.push_to_hub(output_hub_path)
        processor.push_to_hub(output_hub_path)

    # Initialize model with empty weights
    print("Creating empty model...")
    with init_empty_weights():
        model = EomtForUniversalSegmentation(config)

    # Load and convert state dict
    print("Loading state dict...")
    state_dict = load_model_state_dict(input_path)
    state_dict = convert_state_dict_to_hf(state_dict)

    # Load converted state dict
    print("Loading converted weights into model...")
    model.load_state_dict(state_dict, strict=True, assign=True)

    # Save the model
    if output_dir:
        print(f"Saving model to {output_dir}...")
        model.save_pretrained(output_dir, safe_serialization=safe_serialization)
    if output_hub_path:
        print(f"Pushing model to hub at {output_hub_path}...")
        model.push_to_hub(output_hub_path, safe_serialization=safe_serialization)

    del state_dict, model
    gc.collect()

    # Validate the saved model if saved locally
    if output_dir:
        print("Reloading the local model to check if it's saved correctly...")
        EomtForUniversalSegmentation.from_pretrained(output_dir, device_map="auto")
        print("Local model reloaded successfully.")
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--hf_repo_id",
        help="HuggingFace Hub repo ID for the model",
        default=None,
    )
    parser.add_argument(
        "--local_dir",
        help="Local directory containing the model files",
        default=None,
    )
    parser.add_argument(
        "--revision",
        help="Specific revision to download from the Hub",
        default=None,
    )
    parser.add_argument(
        "--output_dir",
        help="Location to write HF model locally",
        default=None,
    )
    parser.add_argument(
        "--output_hub_path",
        help="Repository ID to push model to hub (e.g. 'username/model-name')",
        default=None,
    )
    parser.add_argument(
        "--safe_serialization",
        action="store_true",
        help="Whether to save using safetensors",
    )
    args = parser.parse_args()

    if args.output_dir is None and args.output_hub_path is None:
        raise ValueError("At least one of --output_dir or --output_hub_path must be specified")

    if args.hf_repo_id is None and args.local_dir is None:
        raise ValueError("Either --hf_repo_id or --local_dir must be specified")

    convert_model(
        repo_id=args.hf_repo_id,
        local_dir=args.local_dir,
        output_dir=args.output_dir,
        output_hub_path=args.output_hub_path,
        safe_serialization=args.safe_serialization,
        revision=args.revision,
    )


if __name__ == "__main__":
    main()


@@ -0,0 +1,972 @@
# coding=utf-8
# Copyright 2025 Mobile Perception Systems Lab at TU/e and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for EoMT."""
import math
from typing import Optional, Union
import numpy as np
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import (
PaddingMode,
pad,
resize,
)
from ...image_utils import (
ChannelDimension,
ImageInput,
PILImageResampling,
get_image_size,
infer_channel_dimension_format,
make_flat_list_of_images,
make_list_of_images,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
)
from ...utils import (
IMAGENET_DEFAULT_MEAN,
IMAGENET_DEFAULT_STD,
TensorType,
filter_out_non_signature_kwargs,
is_torch_available,
logging,
)
logger = logging.get_logger(__name__)
if is_torch_available():
import torch
import torch.nn.functional as F
# Adapted from transformers.models.maskformer.image_processing_maskformer.convert_segmentation_map_to_binary_masks
def convert_segmentation_map_to_binary_masks(
segmentation_map: "np.ndarray",
instance_id_to_semantic_id: Optional[dict[int, int]] = None,
ignore_index: Optional[int] = None,
):
if ignore_index is not None:
segmentation_map = np.where(segmentation_map == 0, ignore_index, segmentation_map - 1)
# Get unique ids (class or instance ids based on input)
all_labels = np.unique(segmentation_map)
# Drop background label if applicable
if ignore_index is not None:
all_labels = all_labels[all_labels != ignore_index]
# Generate a binary mask for each object instance
binary_masks = [(segmentation_map == i) for i in all_labels]
# Stack the binary masks
if binary_masks:
binary_masks = np.stack(binary_masks, axis=0)
else:
binary_masks = np.zeros((0, *segmentation_map.shape))
# Convert instance ids to class ids
if instance_id_to_semantic_id is not None:
labels = np.zeros(all_labels.shape[0])
for label in all_labels:
class_id = instance_id_to_semantic_id[label + 1 if ignore_index is not None else label]
labels[all_labels == label] = class_id - 1 if ignore_index is not None else class_id
else:
labels = all_labels
return binary_masks.astype(np.float32), labels.astype(np.int64)
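# Illustrative sketch (not part of the commit): the conversion above, rewritten without NumPy so it can be
# read and run in isolation. Nested lists stand in for arrays, and the helper name `to_binary_masks` is
# hypothetical. It assumes the same convention as the function above: when `ignore_index` is set, id 0 is
# background and every other id is shifted down by one.

```python
def to_binary_masks(segmentation_map, ignore_index=None):
    """Return (masks, labels): one binary mask per id present in the map."""
    if ignore_index is not None:
        # Background (0) becomes ignore_index; all other ids shift down by one.
        segmentation_map = [
            [ignore_index if value == 0 else value - 1 for value in row]
            for row in segmentation_map
        ]
    # Unique ids in ascending order, mirroring np.unique.
    labels = sorted({value for row in segmentation_map for value in row})
    if ignore_index is not None:
        labels = [label for label in labels if label != ignore_index]
    masks = [
        [[1.0 if value == label else 0.0 for value in row] for row in segmentation_map]
        for label in labels
    ]
    return masks, labels


seg_map = [[0, 1, 1], [2, 2, 0]]
masks, labels = to_binary_masks(seg_map, ignore_index=255)
raw_masks, raw_labels = to_binary_masks(seg_map)
```

# With `ignore_index=255` the background pixels are dropped and two masks remain; without it, id 0 is
# treated as an ordinary label and three masks are produced.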
def get_size_with_aspect_ratio(image_size, size, max_size=None) -> tuple[int, int]:
"""
Computes the output image size given the input image size and the desired output size.
Args:
image_size (`Tuple[int, int]`):
The input image size.
size (`int`):
The desired output size.
max_size (`int`, *optional*):
The maximum allowed output size.
"""
height, width = image_size
raw_size = None
if max_size is not None:
min_original_size = float(min((height, width)))
max_original_size = float(max((height, width)))
if max_original_size / min_original_size * size > max_size:
raw_size = max_size * min_original_size / max_original_size
size = int(round(raw_size))
if (height <= width and height == size) or (width <= height and width == size):
oh, ow = height, width
elif width < height:
ow = size
if max_size is not None and raw_size is not None:
oh = round(raw_size * height / width)
else:
oh = round(size * height / width)
else:
oh = size
if max_size is not None and raw_size is not None:
ow = round(raw_size * width / height)
else:
ow = round(size * width / height)
return (oh, ow)
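# Illustrative sketch (not part of the commit): a compact restatement of the size rule above with a few
# worked inputs. The name `size_with_aspect_ratio` is hypothetical; the logic is intended to mirror
# `get_size_with_aspect_ratio`, where `size` targets the shorter edge and `max_size` caps the longer one.

```python
def size_with_aspect_ratio(image_size, size, max_size=None):
    height, width = image_size
    raw_size = None
    if max_size is not None:
        min_side = float(min(height, width))
        max_side = float(max(height, width))
        # If scaling the short edge to `size` would push the long edge past
        # `max_size`, shrink the short-edge target instead.
        if max_side / min_side * size > max_size:
            raw_size = max_size * min_side / max_side
            size = int(round(raw_size))
    # Short edge already matches the target: keep the image size unchanged.
    if (height <= width and height == size) or (width <= height and width == size):
        return height, width
    scale_by = raw_size if raw_size is not None else size
    if width < height:
        return round(scale_by * height / width), size
    return size, round(scale_by * width / height)


landscape = size_with_aspect_ratio((480, 640), size=640, max_size=640)
portrait = size_with_aspect_ratio((1000, 500), size=640, max_size=640)
uncapped = size_with_aspect_ratio((300, 400), size=600)
```

# With `size == max_size` (the processor default of 640 for both edges), the longer edge ends up at 640
# and the shorter edge follows the aspect ratio.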
# Copied from transformers.models.detr.image_processing_detr.remove_low_and_no_objects
def remove_low_and_no_objects(masks, scores, labels, object_mask_threshold, num_labels):
"""
Binarize the given masks using `object_mask_threshold`, it returns the associated values of `masks`, `scores` and
`labels`.
Args:
masks (`torch.Tensor`):
A tensor of shape `(num_queries, height, width)`.
scores (`torch.Tensor`):
A tensor of shape `(num_queries)`.
labels (`torch.Tensor`):
A tensor of shape `(num_queries)`.
object_mask_threshold (`float`):
A number between 0 and 1 used to binarize the masks.
Raises:
`ValueError`: Raised when the first dimension doesn't match in all input tensors.
Returns:
`tuple[`torch.Tensor`, `torch.Tensor`, `torch.Tensor`]`: The `masks`, `scores` and `labels` without the region
< `object_mask_threshold`.
"""
if not (masks.shape[0] == scores.shape[0] == labels.shape[0]):
raise ValueError("mask, scores and labels must have the same shape!")
to_keep = labels.ne(num_labels) & (scores > object_mask_threshold)
return masks[to_keep], scores[to_keep], labels[to_keep]
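# Illustrative sketch (not part of the commit): the keep rule above on plain Python lists. A query
# survives only if its label is not the null class (index `num_labels`) and its score exceeds the
# threshold. The helper name `keep_confident_queries` is hypothetical.

```python
def keep_confident_queries(scores, labels, threshold, num_labels):
    """Return the indices of queries that pass both filters."""
    return [
        index
        for index, (score, label) in enumerate(zip(scores, labels))
        if label != num_labels and score > threshold
    ]


# Query 1 predicts the null class (133) and query 2 is below the threshold.
kept = keep_confident_queries(
    scores=[0.90, 0.95, 0.40], labels=[3, 133, 7], threshold=0.8, num_labels=133
)
```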
def check_segment_validity(mask_labels, mask_probs, k, mask_threshold=0.5, overlap_mask_area_threshold=0.8):
# Get the mask associated with the k class
mask_k = mask_labels == k
mask_k_area = mask_k.sum()
# Compute the area of all the stuff in query k
original_mask = mask_probs[k] >= mask_threshold
original_area = original_mask.sum()
final_mask = mask_k & original_mask
final_mask_area = final_mask.sum()
mask_exists = mask_k_area > 0 and original_area > 0 and final_mask_area > 0
if mask_exists:
area_ratio = mask_k_area / original_area
if not area_ratio.item() > overlap_mask_area_threshold:
mask_exists = False
return mask_exists, final_mask
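# Illustrative sketch (not part of the commit): the validity test above on a flattened list of pixels.
# `assigned` marks pixels where the per-pixel argmax picked this query, `probs` holds the query's
# sigmoid probabilities. The helper name `segment_is_valid` is hypothetical.

```python
def segment_is_valid(assigned, probs, mask_threshold=0.5, overlap_threshold=0.8):
    confident = [prob >= mask_threshold for prob in probs]
    assigned_area = sum(assigned)
    confident_area = sum(confident)
    final_area = sum(a and c for a, c in zip(assigned, confident))
    if assigned_area == 0 or confident_area == 0 or final_area == 0:
        return False
    # The query must win most of the pixels it is confident about.
    return assigned_area / confident_area > overlap_threshold


# The query is confident on 4 pixels but wins only 3 of them: 0.75 is not above 0.8.
mostly_lost = segment_is_valid([True, True, True, False], [0.9, 0.8, 0.7, 0.6])
# The query wins every pixel it is confident about.
fully_won = segment_is_valid([True, True, True, True], [0.9, 0.8, 0.7, 0.6])
```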
def compute_segments(
mask_probs,
pred_scores,
pred_labels,
stuff_classes,
mask_threshold: float = 0.5,
overlap_mask_area_threshold: float = 0.8,
target_size: Optional[tuple[int, int]] = None,
):
height = mask_probs.shape[1] if target_size is None else target_size[0]
width = mask_probs.shape[2] if target_size is None else target_size[1]
segmentation = torch.zeros((height, width), dtype=torch.long, device=mask_probs.device) - 1
segments: list[dict] = []
# Compute per-pixel assignment based on weighted mask scores
mask_probs = mask_probs.sigmoid()
mask_labels = (pred_scores[:, None, None] * mask_probs).argmax(0)
# Keep track of instances of each class
current_segment_id = 0
stuff_memory_list: dict[int, int] = {}
for k in range(pred_labels.shape[0]):
pred_class = pred_labels[k].item()
# Check if mask exists and large enough to be a segment
mask_exists, final_mask = check_segment_validity(
mask_labels, mask_probs, k, mask_threshold, overlap_mask_area_threshold
)
if not mask_exists:
continue
if stuff_classes and pred_class in stuff_classes:
if pred_class in stuff_memory_list:
segmentation[final_mask] = stuff_memory_list[pred_class]
continue
else:
stuff_memory_list[pred_class] = current_segment_id
segmentation[final_mask] = current_segment_id
segment_score = round(pred_scores[k].item(), 6)
segments.append(
{
"id": current_segment_id,
"label_id": pred_class,
"score": segment_score,
}
)
current_segment_id += 1
return segmentation, segments
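# Illustrative sketch (not part of the commit): how segment ids are handed out in the loop above,
# assuming every query already passed the validity check. Repeated "stuff" classes reuse the id of
# their first occurrence, while every "thing" query gets a fresh id. The helper name
# `assign_segment_ids` is hypothetical.

```python
def assign_segment_ids(pred_classes, stuff_classes):
    segment_ids, stuff_memory, next_id = [], {}, 0
    for pred_class in pred_classes:
        if pred_class in stuff_classes and pred_class in stuff_memory:
            # Merge into the segment already created for this stuff class.
            segment_ids.append(stuff_memory[pred_class])
            continue
        if pred_class in stuff_classes:
            stuff_memory[pred_class] = next_id
        segment_ids.append(next_id)
        next_id += 1
    return segment_ids


# Class 5 is "stuff" (merged), class 2 is a "thing" (kept as separate instances).
ids = assign_segment_ids(pred_classes=[5, 2, 5, 2], stuff_classes={5})
```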
def get_target_size(size_dict: dict[str, int]) -> tuple[int, int]:
"""Returns the height and width from a size dict."""
target_height = size_dict["shortest_edge"]
target_width = size_dict.get("longest_edge", None) or target_height
return target_height, target_width
class EomtImageProcessor(BaseImageProcessor):
r"""
Constructs an EoMT image processor. The image processor can be used to prepare image(s) and optional targets
for the model.
This image processor inherits from [`BaseImageProcessor`] which contains most of the main methods. Users should
refer to this superclass for more information regarding those methods.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the input to a certain `size`.
size (`dict[str, int]`, *optional*, defaults to `{"shortest_edge": 640, "longest_edge": 640}`):
Size of the output image. Only has an effect if `do_resize` is set to `True`. The shortest edge of the
image is resized to `size["shortest_edge"]` while keeping the aspect ratio, and the longest edge is capped
at `size["longest_edge"]`.
resample (`int`, *optional*, defaults to `Resampling.BILINEAR`):
An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
`PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
`PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
to `True`.
do_rescale (`bool`, *optional*, defaults to `True`):
Whether to rescale the input to a certain `scale`.
rescale_factor (`float`, *optional*, defaults to `1/255`):
Rescale the input by the given factor. Only has an effect if `do_rescale` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether or not to normalize the input with mean and standard deviation.
do_split_image (`bool`, *optional*, defaults to `False`):
Whether to split the input images into overlapping patches for semantic segmentation. If set to `True`, the
input images will be split into patches of size `size["shortest_edge"]` with an overlap between patches.
Otherwise, the input images will be padded to the target size.
do_pad (`bool`, *optional*, defaults to `False`):
Whether to pad the image. If `True`, each image is zero-padded on the bottom and right up to the target size
given by `size`.
image_mean (`float` or `list[float]`, *optional*, defaults to `[0.485, 0.456, 0.406]`):
The sequence of means for each channel, to be used when normalizing images. Defaults to the ImageNet mean.
image_std (`float` or `list[float]`, *optional*, defaults to `[0.229, 0.224, 0.225]`):
The sequence of standard deviations for each channel, to be used when normalizing images. Defaults to the
ImageNet std.
ignore_index (`int`, *optional*):
Label to be assigned to background pixels in segmentation maps. If provided, segmentation map pixels
denoted with 0 (background) will be replaced with `ignore_index`.
num_labels (`int`, *optional*):
The number of labels in the segmentation map.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize: bool = True,
size: Optional[dict[str, int]] = None,
resample: PILImageResampling = PILImageResampling.BILINEAR,
do_rescale: bool = True,
rescale_factor: float = 1 / 255,
do_normalize: bool = True,
do_split_image: bool = False,
do_pad: bool = False,
image_mean: Optional[Union[float, list[float]]] = None,
image_std: Optional[Union[float, list[float]]] = None,
ignore_index: Optional[int] = None,
num_labels: Optional[int] = None,
**kwargs,
):
super().__init__(**kwargs)
size = size if size is not None else {"shortest_edge": 640, "longest_edge": 640}
size = get_size_dict(size, default_to_square=False)
self.do_resize = do_resize
self.size = size
self.resample = resample
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.do_split_image = do_split_image
self.do_pad = do_pad
self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
self.ignore_index = ignore_index
self.num_labels = num_labels
def resize(
self,
image: np.ndarray,
size: dict,
resample: PILImageResampling = PILImageResampling.BILINEAR,
data_format=None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge
resized to keep the input aspect ratio.
Args:
image (`np.ndarray`):
Image to resize.
size (`dict`):
Size dictionary with `"shortest_edge"` and `"longest_edge"` keys.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
Resampling filter to use when resizing the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
"""
image_size = get_image_size(image)
output_size = get_size_with_aspect_ratio(image_size, size["shortest_edge"], size["longest_edge"])
image = resize(
image=image,
size=output_size,
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
return_numpy=True,
**kwargs,
)
return image
def _split_image(self, image: ImageInput, size: dict, image_index: int) -> tuple[list, list]:
"""Slices an image into overlapping patches for semantic segmentation."""
patches, patch_offsets = [], []
image_size = get_image_size(image)
patch_size = size["shortest_edge"]
longer_side = max(image_size)
num_patches = math.ceil(longer_side / patch_size)
total_overlap = num_patches * patch_size - longer_side
overlap_per_patch = total_overlap / (num_patches - 1) if num_patches > 1 else 0
for i in range(num_patches):
start = int(i * (patch_size - overlap_per_patch))
end = start + patch_size
if image_size[0] > image_size[1]:
patch = image[:, start:end, :]
else:
patch = image[:, :, start:end]
patches.append(patch)
patch_offsets.append([image_index, start, end])
return patches, patch_offsets
def _pad(self, image: ImageInput, size: dict) -> np.ndarray:
"""Pads the image to the target size using zero padding."""
height, width = get_image_size(image)
target_height, target_width = get_target_size(size)
pad_h = max(0, target_height - height)
pad_w = max(0, target_width - width)
padding = ((0, pad_h), (0, pad_w))
# Channel axis is last; default padding format is compatible
padded_image = pad(image=image, padding=padding, mode=PaddingMode.CONSTANT, constant_values=0.0)
return padded_image
def _preprocess_images(
self,
images: ImageInput,
do_resize: Optional[bool] = None,
size: Optional[dict[str, int]] = None,
resample: PILImageResampling = None,
do_split_image: Optional[bool] = None,
do_pad: Optional[bool] = None,
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, list[float]]] = None,
image_std: Optional[Union[float, list[float]]] = None,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> tuple[list[np.ndarray], list[list[int]]]:
"""Preprocesses a batch of images."""
images = [to_numpy_array(image) for image in images]
if do_resize:
images = [
self.resize(
image,
size=size,
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
)
for image in images
]
processed_images, patch_offsets = [], []
if do_split_image:
for idx, img in enumerate(images):
patches, offsets = self._split_image(img, size, idx)
processed_images.extend(patches)
patch_offsets.extend(offsets)
images = processed_images
if do_pad:
images = [self._pad(img, size) for img in images]
if do_rescale:
images = [self.rescale(img, scale=rescale_factor, input_data_format=input_data_format) for img in images]
if do_normalize:
images = [
self.normalize(
image,
mean=image_mean,
std=image_std,
input_data_format=input_data_format,
)
for image in images
]
return images, patch_offsets
def _preprocess_mask(
self,
segmentation_map: ImageInput,
do_resize: Optional[bool] = False,
do_pad: Optional[bool] = False,
size: Optional[dict[str, int]] = None,
resample: PILImageResampling = None,
data_format: Union[str, ChannelDimension] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> "torch.Tensor":
"""Preprocesses a single mask."""
# Add channel dimension if missing - needed for certain transformations
if segmentation_map.ndim == 2:
added_channel_dim = True
segmentation_map = segmentation_map[None, ...]
input_data_format = ChannelDimension.FIRST
else:
added_channel_dim = False
if input_data_format is None:
input_data_format = infer_channel_dimension_format(segmentation_map)
if do_resize:
segmentation_map = self.resize(
segmentation_map,
size=size,
resample=resample,
data_format=data_format,
)
if do_pad:
segmentation_map = self._pad(segmentation_map, size)
# Remove extra channel dimension if added for processing
if added_channel_dim:
segmentation_map = segmentation_map.squeeze(0)
return torch.from_numpy(segmentation_map)
@filter_out_non_signature_kwargs()
def preprocess(
self,
images: ImageInput,
segmentation_maps: Optional[ImageInput] = None,
instance_id_to_semantic_id: Optional[Union[list[dict[int, int]], dict[int, int]]] = None,
do_split_image: Optional[bool] = None,
do_resize: Optional[bool] = None,
size: Optional[dict[str, int]] = None,
resample: PILImageResampling = None,
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
do_normalize: Optional[bool] = None,
do_pad: Optional[bool] = None,
image_mean: Optional[Union[float, list[float]]] = None,
image_std: Optional[Union[float, list[float]]] = None,
ignore_index: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> BatchFeature:
"""
Preprocesses images or a batch of images.
Args:
images (`ImageInput`):
Image or batch of images to preprocess.
segmentation_maps (`ImageInput`, *optional*):
The corresponding semantic segmentation maps with the pixel-wise annotations.
instance_id_to_semantic_id (`List[Dict[int, int]]` or `Dict[int, int]`, *optional*):
A mapping between object instance ids and class ids.
do_split_image (`bool`, *optional*, defaults to `self.do_split_image`):
Whether to split the input images into overlapping patches for semantic segmentation.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the input images.
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
Target size as a dictionary with `"shortest_edge"` and `"longest_edge"` keys.
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
Resampling filter to use when resizing.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
Whether to rescale the input images by `rescale_factor`.
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
Factor to scale image pixel values.
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
Whether to normalize the input images.
do_pad (`bool`, *optional*, defaults to `self.do_pad`):
Whether to pad the image. If `True`, each image is zero-padded on the bottom and right up to the target
size given by `size`.
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
Mean for normalization. Single value or list for each channel.
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
Standard deviation for normalization. Single value or list for each channel.
ignore_index (`int`, *optional*):
Label to be assigned to background pixels in segmentation maps. If provided, segmentation map pixels
denoted with 0 (background) will be replaced with `ignore_index`.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be `"pt"`, `"tf"`, `"np"`, or `"jax"`.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
Channel format of the output image. Either `"channels_first"` or `"channels_last"`.
input_data_format (`ChannelDimension` or `str`, *optional*):
Channel format of the input image.
"""
do_split_image = do_split_image if do_split_image is not None else self.do_split_image
do_resize = do_resize if do_resize is not None else self.do_resize
size = size if size is not None else self.size
size = get_size_dict(size, default_to_square=False)
resample = resample if resample is not None else self.resample
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
do_pad = do_pad if do_pad is not None else self.do_pad
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
ignore_index = ignore_index if ignore_index is not None else self.ignore_index
images = make_flat_list_of_images(images)
if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_resize=do_resize,
size=size,
resample=resample,
)
pixel_values_list, patch_offsets = self._preprocess_images(
images=images,
do_resize=do_resize,
size=size,
resample=resample,
do_split_image=do_split_image,
do_pad=do_pad,
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
data_format=data_format,
input_data_format=input_data_format,
)
if segmentation_maps is not None:
segmentation_maps = make_list_of_images(segmentation_maps, expected_ndims=2)
segmentation_maps = [to_numpy_array(mask) for mask in segmentation_maps]
segmentation_maps = [
self._preprocess_mask(
segmentation_map,
do_resize=do_resize,
do_pad=do_pad,
size=size,
resample=PILImageResampling.NEAREST,
data_format=data_format,
input_data_format=input_data_format,
)
for segmentation_map in segmentation_maps
]
encoded_inputs = self.encode_inputs(
pixel_values_list,
segmentation_maps,
instance_id_to_semantic_id,
ignore_index,
return_tensors,
input_data_format=data_format,
)
if do_split_image and patch_offsets:
encoded_inputs["patch_offsets"] = patch_offsets
return encoded_inputs
def encode_inputs(
self,
pixel_values_list: list[ImageInput],
segmentation_maps: ImageInput = None,
instance_id_to_semantic_id: Optional[Union[list[dict[int, int]], dict[int, int]]] = None,
ignore_index: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
):
"""
Wraps the preprocessed images in a [`BatchFeature`] and, if segmentation maps are given, converts them to
mask labels and class labels.
EoMT addresses semantic segmentation with a mask classification paradigm, thus input segmentation maps
will be converted to lists of binary masks and their respective labels. For example, assuming
`segmentation_maps = [[2,6,7,9]]`, the output will contain `mask_labels =
[[1,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,1]]` (four binary masks) and `class_labels = [2,6,7,9]`, the labels for
each mask.
Args:
pixel_values_list (`List[ImageInput]`):
List of preprocessed images (pixel values). Each image should be an array of shape `(channels, height,
width)`.
segmentation_maps (`ImageInput`, *optional*):
The corresponding semantic segmentation maps with the pixel-wise annotations.
instance_id_to_semantic_id (`List[Dict[int, int]]` or `Dict[int, int]`, *optional*):
A mapping between object instance ids and class ids. If passed, `segmentation_maps` is treated as an
instance segmentation map where each pixel represents an instance id. Can be provided as a single
dictionary with a global/dataset-level mapping or as a list of dictionaries (one per image), to map
instance ids in each image separately.
ignore_index (`int`, *optional*):
Label to be assigned to background pixels in segmentation maps. If provided, segmentation map pixels
denoted with 0 (background) will be replaced with `ignore_index`. Falls back to `self.ignore_index`.
return_tensors (`str` or [`~file_utils.TensorType`], *optional*):
If set, will return tensors instead of NumPy arrays. If set to `'pt'`, return PyTorch `torch.Tensor`
objects.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **pixel_values** -- Pixel values to be fed to a model.
- **mask_labels** -- Optional list of mask labels of shape `(labels, height, width)` to be fed to a model
(when `segmentation_maps` are provided).
- **class_labels** -- Optional list of class labels of shape `(labels)` to be fed to a model (when
`segmentation_maps` are provided). They identify the labels of `mask_labels`, e.g. the label of
`mask_labels[i][j]` is `class_labels[i][j]`.
"""
ignore_index = self.ignore_index if ignore_index is None else ignore_index
pixel_values_list = [to_numpy_array(pixel_values) for pixel_values in pixel_values_list]
if input_data_format is None:
input_data_format = infer_channel_dimension_format(pixel_values_list[0])
encoded_inputs = BatchFeature({"pixel_values": pixel_values_list}, tensor_type=return_tensors)
if segmentation_maps is not None:
mask_labels = []
class_labels = []
# Convert to list of binary masks and labels
for idx, segmentation_map in enumerate(segmentation_maps):
segmentation_map = to_numpy_array(segmentation_map)
if isinstance(instance_id_to_semantic_id, list):
instance_id = instance_id_to_semantic_id[idx]
else:
instance_id = instance_id_to_semantic_id
# Use instance2class_id mapping per image
masks, classes = convert_segmentation_map_to_binary_masks(
segmentation_map,
instance_id,
ignore_index=ignore_index,
)
mask_labels.append(torch.from_numpy(masks))
class_labels.append(torch.from_numpy(classes))
# we cannot batch them since they don't share a common class size
encoded_inputs["mask_labels"] = mask_labels
encoded_inputs["class_labels"] = class_labels
return encoded_inputs
def merge_image_patches(
self,
segmentation_logits: torch.Tensor,
patch_offsets: list[tuple[int, int, int]],
original_image_sizes: list[tuple[int, int]],
size: dict[str, int],
) -> list[torch.Tensor]:
"""
Reconstructs full-size semantic segmentation logits from patch predictions.
Args:
segmentation_logits (`torch.Tensor`):
A tensor of shape `(num_patches, num_classes, patch_height, patch_width)` representing predicted logits
for each image patch.
patch_offsets (`List[Tuple[int, int, int]]`):
A list of tuples where each tuple contains:
- `image_index` (int): Index of the original image this patch belongs to.
- `start` (int): Start pixel index of the patch along the long dimension (height or width).
- `end` (int): End pixel index of the patch along the long dimension.
original_image_sizes (`List[Tuple[int, int]]`):
List of original (height, width) dimensions for each image before preprocessing.
size (`Dict[str, int]`):
A size dict which was used to resize.
"""
num_classes = segmentation_logits.shape[1]
aggregated_logits = []
patch_counts = []
for image_size in original_image_sizes:
height, width = get_size_with_aspect_ratio(image_size, size["shortest_edge"], size["longest_edge"])
aggregated_logits.append(torch.zeros((num_classes, height, width), device=segmentation_logits.device))
patch_counts.append(torch.zeros((num_classes, height, width), device=segmentation_logits.device))
# Stitch patches back into full-sized logit maps
for patch_idx, (image_idx, patch_start, patch_end) in enumerate(patch_offsets):
if original_image_sizes[image_idx][0] > original_image_sizes[image_idx][1]:
aggregated_logits[image_idx][:, patch_start:patch_end, :] += segmentation_logits[patch_idx]
patch_counts[image_idx][:, patch_start:patch_end, :] += 1
else:
aggregated_logits[image_idx][:, :, patch_start:patch_end] += segmentation_logits[patch_idx]
patch_counts[image_idx][:, :, patch_start:patch_end] += 1
# Normalize and resize logits to original image size
reconstructed_logits = []
for idx, (logit_sum, count) in enumerate(zip(aggregated_logits, patch_counts)):
averaged_logits = logit_sum / count.clamp(min=1)
resized_logits = F.interpolate(
averaged_logits[None, ...],
size=original_image_sizes[idx],
mode="bilinear",
align_corners=False,
)[0]
reconstructed_logits.append(resized_logits)
return reconstructed_logits
def unpad_image(
self,
segmentation_logits: torch.Tensor,
original_image_sizes: list[tuple[int, int]],
size: dict[str, int],
) -> list[torch.Tensor]:
"""Restores panoptic segmentation logits to their original image resolutions."""
resized_logits = []
for idx, original_size in enumerate(original_image_sizes):
target_height, target_width = get_size_with_aspect_ratio(
original_size, size["shortest_edge"], size["longest_edge"]
)
cropped_logits = segmentation_logits[idx][:, :target_height, :target_width]
upsampled_logits = F.interpolate(
cropped_logits[None, ...], size=original_size, mode="bilinear", align_corners=False
)[0]
resized_logits.append(upsampled_logits)
return resized_logits
def post_process_semantic_segmentation(
self,
outputs,
patch_offsets: list[tuple[int, int, int]],
original_image_sizes: list[tuple[int, int]],
size: Optional[dict[str, int]] = None,
) -> "torch.Tensor":
"""Post-processes model outputs into final semantic segmentation prediction."""
size = size if size is not None else self.size
masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width]
class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1]
output_size = get_target_size(size)
masks_queries_logits = F.interpolate(
masks_queries_logits,
size=output_size,
mode="bilinear",
)
# Remove the null class `[..., :-1]`
masks_classes = class_queries_logits.softmax(dim=-1)[..., :-1]
masks_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width]
segmentation_logits = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs)
output_logits = self.merge_image_patches(segmentation_logits, patch_offsets, original_image_sizes, size)
preds = torch.stack(output_logits).argmax(dim=1)
return preds
def post_process_panoptic_segmentation(
self,
outputs,
original_image_sizes: list[tuple[int, int]],
threshold: float = 0.8,
mask_threshold: float = 0.5,
overlap_mask_area_threshold: float = 0.8,
stuff_classes: Optional[list[int]] = None,
size: Optional[dict[str, int]] = None,
):
"""Post-processes model outputs into final panoptic segmentation prediction."""
size = size if size is not None else self.size
masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width]
class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1]
batch_size = class_queries_logits.shape[0]
num_labels = class_queries_logits.shape[-1] - 1
output_size = get_target_size(size)
masks_queries_logits = F.interpolate(
masks_queries_logits,
size=output_size,
mode="bilinear",
)
mask_probs_batch = self.unpad_image(masks_queries_logits, original_image_sizes, size)
pred_scores_batch, pred_labels_batch = class_queries_logits.softmax(dim=-1).max(-1)
results: list = []
for i in range(batch_size):
mask_probs, pred_scores, pred_labels = remove_low_and_no_objects(
mask_probs_batch[i], pred_scores_batch[i], pred_labels_batch[i], threshold, num_labels
)
# No mask found
if mask_probs.shape[0] <= 0:
height, width = original_image_sizes[i] if original_image_sizes is not None else mask_probs.shape[1:]
segmentation = torch.zeros((height, width)) - 1
results.append({"segmentation": segmentation, "segments_info": []})
continue
segmentation, segments = compute_segments(
mask_probs=mask_probs,
pred_scores=pred_scores,
pred_labels=pred_labels,
stuff_classes=stuff_classes,
mask_threshold=mask_threshold,
overlap_mask_area_threshold=overlap_mask_area_threshold,
target_size=original_image_sizes[i] if original_image_sizes is not None else None,
)
results.append({"segmentation": segmentation, "segments_info": segments})
return results
def post_process_instance_segmentation(
self,
outputs,
original_image_sizes: list[tuple[int, int]],
threshold: float = 0.5,
size: Optional[dict[str, int]] = None,
):
"""Post-processes model outputs into Instance Segmentation Predictions."""
size = size if size is not None else self.size
class_queries_logits = outputs.class_queries_logits
masks_queries_logits = outputs.masks_queries_logits
output_size = get_target_size(size)
masks_queries_logits = F.interpolate(
masks_queries_logits,
size=output_size,
mode="bilinear",
)
mask_probs_batch = self.unpad_image(masks_queries_logits, original_image_sizes, size)
device = masks_queries_logits.device
batch_size = class_queries_logits.shape[0]
num_queries = class_queries_logits.shape[-2]
results = []
for i in range(batch_size):
mask_pred = mask_probs_batch[i]
mask_class = class_queries_logits[i]
# Remove the null class `[..., :-1]`
scores, pred_classes = mask_class.softmax(dim=-1)[..., :-1].max(-1)
pred_masks = (mask_pred > 0).float()
# Calculate average mask prob
mask_scores = (mask_pred.sigmoid().flatten(1) * pred_masks.flatten(1)).sum(1) / (
pred_masks.flatten(1).sum(1) + 1e-6
)
pred_scores = scores * mask_scores
segmentation = torch.zeros(original_image_sizes[i], device=device) - 1
instance_maps, segments = [], []
current_segment_id = 0
for j in range(num_queries):
score = pred_scores[j].item()
if not torch.all(pred_masks[j] == 0) and score >= threshold:
segmentation[pred_masks[j] == 1] = current_segment_id
segments.append(
{
"id": current_segment_id,
"label_id": pred_classes[j].item(),
"score": round(score, 6),
}
)
current_segment_id += 1
instance_maps.append(pred_masks[j])
results.append({"segmentation": segmentation, "segments_info": segments})
return results
__all__ = ["EomtImageProcessor"]
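
The instance post-processing above scores each query as its best non-null class probability multiplied by the mean sigmoid probability over the pixels its binary mask keeps (logit > 0). A minimal pure-Python sketch of that arithmetic for a single query follows. The helper name `instance_score` and its flat-list input are hypothetical and used only for illustration; they are not part of the processor's API.

```python
import math


def instance_score(class_prob: float, mask_logits: list[float], eps: float = 1e-6) -> float:
    """Hypothetical helper: class probability times mean foreground mask probability."""
    # A pixel belongs to the binary mask when its logit is positive.
    foreground = [logit for logit in mask_logits if logit > 0]
    # Sum the sigmoid probabilities of the foreground pixels only.
    prob_sum = sum(1.0 / (1.0 + math.exp(-logit)) for logit in foreground)
    # Divide by the foreground pixel count; eps avoids division by zero for empty masks.
    mask_score = prob_sum / (len(foreground) + eps)
    return class_prob * mask_score
```

A query whose mask is empty gets a score of zero, so it can never pass `threshold`; this mirrors the `1e-6` term in the denominator of the original expression.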


@@ -0,0 +1,580 @@
# coding=utf-8
# Copyright 2025 Mobile Perception Systems Lab at TU/e and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Image processor class for EoMT."""
import math
from typing import Optional, Union
import numpy as np
from ...image_processing_utils import BatchFeature
from ...image_processing_utils_fast import (
BaseImageProcessorFast,
DefaultFastImageProcessorKwargs,
group_images_by_shape,
reorder_images,
)
from ...image_utils import (
IMAGENET_DEFAULT_MEAN,
IMAGENET_DEFAULT_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
SizeDict,
make_list_of_images,
pil_torch_interpolation_mapping,
validate_kwargs,
)
from ...processing_utils import Unpack
from ...utils import (
TensorType,
auto_docstring,
is_torch_available,
is_torchvision_available,
is_torchvision_v2_available,
)
from .image_processing_eomt import (
compute_segments,
convert_segmentation_map_to_binary_masks,
get_size_with_aspect_ratio,
remove_low_and_no_objects,
)
if is_torch_available():
import torch
if is_torchvision_available():
if is_torchvision_v2_available():
from torchvision.transforms.v2 import functional as F
else:
from torchvision.transforms import functional as F
class EomtImageProcessorFastKwargs(DefaultFastImageProcessorKwargs):
"""
do_split_image (`bool`, *optional*, defaults to `False`):
Whether to split the input images into overlapping patches for semantic segmentation. If set to `True`, the
input images will be split into patches of size `size["shortest_edge"]` with an overlap between patches.
Otherwise, the input images will be padded to the target size.
do_pad (`bool`, *optional*, defaults to `False`):
Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
number of patches in the batch. Padding will be applied to the bottom and right with zeros.
ignore_index (`int`, *optional*):
Label to be assigned to background pixels in segmentation maps. If provided, segmentation map pixels
denoted with 0 (background) will be replaced with `ignore_index`.
"""
do_split_image: bool
do_pad: bool
ignore_index: Optional[int] = None
def get_target_size(size_dict: dict[str, int]) -> tuple[int, int]:
"""Returns the height and width from a size dict."""
target_height = size_dict["shortest_edge"]
target_width = size_dict["longest_edge"] or target_height
return target_height, target_width
def reorder_patches_and_offsets(
patches: list[torch.Tensor], offsets: list[list[int]]
) -> tuple[list[torch.Tensor], list[list[int]]]:
"""Sorts patches and offsets according to the original image index."""
combined = list(zip(offsets, patches))
combined.sort(key=lambda x: x[0][0])
sorted_offsets, sorted_patches = zip(*combined)
return list(sorted_patches), list(sorted_offsets)
@auto_docstring
class EomtImageProcessorFast(BaseImageProcessorFast):
resample = PILImageResampling.BILINEAR
image_mean = IMAGENET_DEFAULT_MEAN
image_std = IMAGENET_DEFAULT_STD
size = {"shortest_edge": 640, "longest_edge": 640}
default_to_square = False
do_resize = True
do_rescale = True
do_normalize = True
do_split_image = False
do_pad = False
ignore_index = None
valid_kwargs = EomtImageProcessorFastKwargs
def __init__(self, **kwargs: Unpack[EomtImageProcessorFastKwargs]):
super().__init__(**kwargs)
def _split_image(self, images: torch.Tensor, size: dict, image_indices: list[int]) -> tuple[list, list]:
"""Slices an image into overlapping patches for semantic segmentation."""
patches, patch_offsets = [], []
_, _, height, width = images.shape
patch_size = size["shortest_edge"]
longer_side = max(height, width)
num_patches = math.ceil(longer_side / patch_size)
total_overlap = num_patches * patch_size - longer_side
overlap_per_patch = total_overlap / (num_patches - 1) if num_patches > 1 else 0
for i in range(num_patches):
start = int(i * (patch_size - overlap_per_patch))
end = start + patch_size
if height > width:
batch_patch = images[:, :, start:end, :]
else:
batch_patch = images[:, :, :, start:end]
for batch_idx, single in enumerate(torch.unbind(batch_patch, dim=0)):
patches.append(single)
patch_offsets.append([image_indices[batch_idx], start, end])
return patches, patch_offsets
def _pad(self, images: torch.Tensor, size: dict) -> torch.Tensor:
"""Pads the image to the target size using zero padding."""
_, _, height, width = images.shape
target_height, target_width = get_target_size(size)
pad_h = max(0, target_height - height)
pad_w = max(0, target_width - width)
padding = (0, pad_w, 0, pad_h)
padded_images = torch.nn.functional.pad(images, padding, mode="constant", value=0.0)
return padded_images
def _preprocess(
self,
images: list["torch.Tensor"],
do_resize: bool,
size: SizeDict,
interpolation: Optional["F.InterpolationMode"],
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
do_split_image: bool,
do_pad: bool,
image_mean: Optional[Union[float, list[float]]],
image_std: Optional[Union[float, list[float]]],
disable_grouping: Optional[bool],
return_tensors: Optional[Union[str, TensorType]],
**kwargs,
):
"""Preprocesses the input images and masks if provided."""
processed_images, patch_offsets = [], []
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
resized_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_resize:
stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation)
resized_images_grouped[shape] = stacked_images
images = reorder_images(resized_images_grouped, grouped_images_index)
# Regroup images by shape for batched processing; shapes may still differ when do_resize is False.
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
processed_images_grouped = {}
for shape, stacked_images in grouped_images.items():
original_indices = [
original_idx for original_idx, (img_shape, _) in grouped_images_index.items() if img_shape == shape
]
if do_split_image:
patches, offsets = self._split_image(stacked_images, size, original_indices)
processed_images.extend(patches)
patch_offsets.extend(offsets)
if do_pad:
stacked_images = self._pad(stacked_images, size)
processed_images_grouped[shape] = stacked_images
if do_split_image:
images, patch_offsets = reorder_patches_and_offsets(processed_images, patch_offsets)
if do_pad:
images = reorder_images(processed_images_grouped, grouped_images_index)
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
processed_images_grouped = {}
for shape, stacked_images in grouped_images.items():
stacked_images = self.rescale_and_normalize(
stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
)
processed_images_grouped[shape] = stacked_images
images = reorder_images(processed_images_grouped, grouped_images_index)
processed_images = torch.stack(images, dim=0) if return_tensors else images
return processed_images, patch_offsets
def _preprocess_images(self, images, **kwargs):
"""Preprocesses the input images."""
return self._preprocess(images, **kwargs)
def _preprocess_masks(self, segmentation_maps: list[torch.Tensor], **kwargs):
"""Preprocesses segmentation maps."""
processed_segmentation_maps = []
for segmentation_map in segmentation_maps:
segmentation_map = self._process_image(
segmentation_map, do_convert_rgb=False, input_data_format=ChannelDimension.FIRST
)
if segmentation_map.ndim == 2:
segmentation_map = segmentation_map[None, ...]
processed_segmentation_maps.append(segmentation_map)
kwargs["do_normalize"] = False
kwargs["do_rescale"] = False
kwargs["input_data_format"] = ChannelDimension.FIRST
# Nearest interpolation is used for segmentation maps instead of BILINEAR.
kwargs["interpolation"] = pil_torch_interpolation_mapping[PILImageResampling.NEAREST]
processed_segmentation_maps, _ = self._preprocess(images=processed_segmentation_maps, **kwargs)
processed_segmentation_maps = processed_segmentation_maps.squeeze(1)
processed_segmentation_maps = processed_segmentation_maps.to(torch.int64)
return processed_segmentation_maps
@auto_docstring
def preprocess(
self,
images: ImageInput,
segmentation_maps: Optional[list[torch.Tensor]] = None,
instance_id_to_semantic_id: Optional[dict[int, int]] = None,
**kwargs: Unpack[EomtImageProcessorFastKwargs],
) -> BatchFeature:
r"""
segmentation_maps (`ImageInput`, *optional*):
The segmentation maps to preprocess for corresponding images.
instance_id_to_semantic_id (`List[Dict[int, int]]` or `Dict[int, int]`, *optional*):
A mapping between object instance ids and class ids.
"""
# args are not validated, but their order in the `preprocess` and `_preprocess` signatures must be the same
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_kwargs_names)
# Set default kwargs from self. This ensures that if a kwarg is not provided
# by the user, it gets its default value from the instance, or is set to None.
for kwarg_name in self._valid_kwargs_names:
kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
# Extract parameters that are only used for preparing the input images
do_convert_rgb = kwargs.pop("do_convert_rgb")
input_data_format = kwargs.pop("input_data_format")
device = kwargs.pop("device")
# Prepare input images
images = self._prepare_input_images(
images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
)
# Prepare segmentation maps
if segmentation_maps is not None:
segmentation_maps = make_list_of_images(images=segmentation_maps, expected_ndims=2)
# Update kwargs that need further processing before being validated
kwargs = self._further_process_kwargs(**kwargs)
# Validate kwargs
self._validate_preprocess_kwargs(**kwargs)
# torch resize uses interpolation instead of resample
resample = kwargs.pop("resample")
# Check if resample is an int before checking if it's an instance of PILImageResampling
# because if pillow < 9.1.0, resample is an int and PILImageResampling is a module.
# Checking PILImageResampling will fail with error `TypeError: isinstance() arg 2 must be a type or tuple of types`.
kwargs["interpolation"] = (
pil_torch_interpolation_mapping[resample] if isinstance(resample, (int, PILImageResampling)) else resample
)
# Pop kwargs that are not needed in _preprocess
kwargs.pop("default_to_square")
kwargs.pop("data_format")
ignore_index = kwargs.pop("ignore_index", None)
processed_images, patch_offsets = self._preprocess_images(images=images, **kwargs)
outputs = BatchFeature({"pixel_values": processed_images})
mask_labels, class_labels = [], []
if segmentation_maps is not None:
segmentation_maps = self._preprocess_masks(segmentation_maps=segmentation_maps, **kwargs)
# Convert to list of binary masks and labels
for idx, segmentation_map in enumerate(segmentation_maps):
if isinstance(instance_id_to_semantic_id, list):
instance_id = instance_id_to_semantic_id[idx]
else:
instance_id = instance_id_to_semantic_id
# Use instance2class_id mapping per image
masks, classes = convert_segmentation_map_to_binary_masks(
segmentation_map,
instance_id,
ignore_index=ignore_index,
)
mask_labels.append(torch.from_numpy(masks))
class_labels.append(torch.from_numpy(classes))
# we cannot batch them since they don't share a common class size
outputs["mask_labels"] = mask_labels
outputs["class_labels"] = class_labels
if patch_offsets:
outputs["patch_offsets"] = patch_offsets
return outputs
def merge_image_patches(
self,
segmentation_logits: torch.Tensor,
patch_offsets: list[tuple[int, int, int]],
original_image_sizes: list[tuple[int, int]],
size: dict[str, int],
) -> list[torch.Tensor]:
"""
Reconstructs full-size semantic segmentation logits from patch predictions.
Args:
segmentation_logits (`torch.Tensor`):
A tensor of shape `(num_patches, num_classes, patch_height, patch_width)` representing predicted logits
for each image patch.
patch_offsets (`List[Tuple[int, int, int]]`):
A list of tuples where each tuple contains:
- `image_index` (int): Index of the original image this patch belongs to.
- `start` (int): Start pixel index of the patch along the long dimension (height or width).
- `end` (int): End pixel index of the patch along the long dimension.
original_image_sizes (`List[Tuple[int, int]]`):
List of original (height, width) dimensions for each image before preprocessing.
size (`Dict[str, int]`):
A size dict which was used to resize.
"""
num_classes = segmentation_logits.shape[1]
aggregated_logits = []
patch_counts = []
for image_size in original_image_sizes:
height, width = get_size_with_aspect_ratio(image_size, size["shortest_edge"], size["longest_edge"])
aggregated_logits.append(torch.zeros((num_classes, height, width), device=segmentation_logits.device))
patch_counts.append(torch.zeros((num_classes, height, width), device=segmentation_logits.device))
# Stitch patches back into full-sized logit maps
for patch_idx, (image_idx, patch_start, patch_end) in enumerate(patch_offsets):
if original_image_sizes[image_idx][0] > original_image_sizes[image_idx][1]:
aggregated_logits[image_idx][:, patch_start:patch_end, :] += segmentation_logits[patch_idx]
patch_counts[image_idx][:, patch_start:patch_end, :] += 1
else:
aggregated_logits[image_idx][:, :, patch_start:patch_end] += segmentation_logits[patch_idx]
patch_counts[image_idx][:, :, patch_start:patch_end] += 1
# Normalize and resize logits to original image size
reconstructed_logits = []
for idx, (logit_sum, count) in enumerate(zip(aggregated_logits, patch_counts)):
averaged_logits = logit_sum / count.clamp(min=1)
resized_logits = torch.nn.functional.interpolate(
averaged_logits[None, ...],
size=original_image_sizes[idx],
mode="bilinear",
align_corners=False,
)[0]
reconstructed_logits.append(resized_logits)
return reconstructed_logits
def unpad_image(
self,
segmentation_logits: torch.Tensor,
original_image_sizes: list[tuple[int, int]],
size: dict[str, int],
) -> list[torch.Tensor]:
"""Restores panoptic segmentation logits to their original image resolutions."""
resized_logits = []
for idx, original_size in enumerate(original_image_sizes):
target_height, target_width = get_size_with_aspect_ratio(
original_size, size["shortest_edge"], size["longest_edge"]
)
cropped_logits = segmentation_logits[idx][:, :target_height, :target_width]
upsampled_logits = torch.nn.functional.interpolate(
cropped_logits[None, ...], size=original_size, mode="bilinear", align_corners=False
)[0]
resized_logits.append(upsampled_logits)
return resized_logits
def post_process_semantic_segmentation(
self,
outputs,
patch_offsets: list[tuple[int, int, int]],
original_image_sizes: list[tuple[int, int]],
size: Optional[dict[str, int]] = None,
) -> torch.Tensor:
"""Post-processes model outputs into final semantic segmentation prediction."""
size = size if size is not None else self.size
masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width]
class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1]
output_size = get_target_size(size)
masks_queries_logits = torch.nn.functional.interpolate(
masks_queries_logits,
size=output_size,
mode="bilinear",
)
# Remove the null class `[..., :-1]`
masks_classes = class_queries_logits.softmax(dim=-1)[..., :-1]
masks_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width]
segmentation_logits = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs)
output_logits = self.merge_image_patches(segmentation_logits, patch_offsets, original_image_sizes, size)
preds = torch.stack(output_logits).argmax(dim=1)
return preds
def post_process_panoptic_segmentation(
self,
outputs,
original_image_sizes: list[tuple[int, int]],
threshold: float = 0.8,
mask_threshold: float = 0.5,
overlap_mask_area_threshold: float = 0.8,
stuff_classes: Optional[list[int]] = None,
size: Optional[dict[str, int]] = None,
):
"""Post-processes model outputs into final panoptic segmentation prediction."""
size = size if size is not None else self.size
masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width]
class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1]
batch_size = class_queries_logits.shape[0]
num_labels = class_queries_logits.shape[-1] - 1
output_size = get_target_size(size)
masks_queries_logits = torch.nn.functional.interpolate(
masks_queries_logits,
size=output_size,
mode="bilinear",
)
mask_probs_batch = self.unpad_image(masks_queries_logits, original_image_sizes, size)
pred_scores_batch, pred_labels_batch = class_queries_logits.softmax(dim=-1).max(-1)
results: list = []
for i in range(batch_size):
mask_probs, pred_scores, pred_labels = remove_low_and_no_objects(
mask_probs_batch[i], pred_scores_batch[i], pred_labels_batch[i], threshold, num_labels
)
# No mask found
if mask_probs.shape[0] <= 0:
height, width = original_image_sizes[i] if original_image_sizes is not None else mask_probs.shape[1:]
segmentation = torch.zeros((height, width)) - 1
results.append({"segmentation": segmentation, "segments_info": []})
continue
segmentation, segments = compute_segments(
mask_probs=mask_probs,
pred_scores=pred_scores,
pred_labels=pred_labels,
stuff_classes=stuff_classes,
mask_threshold=mask_threshold,
overlap_mask_area_threshold=overlap_mask_area_threshold,
target_size=original_image_sizes[i] if original_image_sizes is not None else None,
)
results.append({"segmentation": segmentation, "segments_info": segments})
return results
def post_process_instance_segmentation(
self,
outputs,
original_image_sizes: list[tuple[int, int]],
threshold: float = 0.8,
size: Optional[dict[str, int]] = None,
):
"""Post-processes model outputs into Instance Segmentation Predictions."""
size = size if size is not None else self.size
masks_queries_logits = outputs.masks_queries_logits
class_queries_logits = outputs.class_queries_logits
output_size = get_target_size(size)
masks_queries_logits = torch.nn.functional.interpolate(
masks_queries_logits,
size=output_size,
mode="bilinear",
)
mask_probs_batch = self.unpad_image(masks_queries_logits, original_image_sizes, size)
device = masks_queries_logits.device
batch_size = class_queries_logits.shape[0]
num_queries = class_queries_logits.shape[-2]
results = []
for i in range(batch_size):
mask_pred = mask_probs_batch[i]
mask_class = class_queries_logits[i]
# Remove the null class `[..., :-1]`
scores, pred_classes = mask_class.softmax(dim=-1)[..., :-1].max(-1)
pred_masks = (mask_pred > 0).float()
# Calculate average mask prob
mask_scores = (mask_pred.sigmoid().flatten(1) * pred_masks.flatten(1)).sum(1) / (
pred_masks.flatten(1).sum(1) + 1e-6
)
pred_scores = scores * mask_scores
segmentation = torch.zeros(original_image_sizes[i], device=device) - 1
instance_maps, segments = [], []
current_segment_id = 0
for j in range(num_queries):
score = pred_scores[j].item()
if not torch.all(pred_masks[j] == 0) and score >= threshold:
segmentation[pred_masks[j] == 1] = current_segment_id
segments.append(
{
"id": current_segment_id,
"label_id": pred_classes[j].item(),
"score": round(score, 6),
}
)
current_segment_id += 1
instance_maps.append(pred_masks[j])
results.append({"segmentation": segmentation, "segments_info": segments})
return results
__all__ = ["EomtImageProcessorFast"]
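
The patch arithmetic in `_split_image` and the averaging in `merge_image_patches` can be sketched in one dimension. This is a simplified illustration, not the processor's implementation: `patch_spans` and `merge_1d` are hypothetical names, the input is assumed to be at least one patch long along its longer side, and plain lists stand in for tensors.

```python
import math


def patch_spans(longer_side: int, patch_size: int) -> list[tuple[int, int]]:
    """Hypothetical helper: (start, end) offsets of evenly overlapping patches."""
    num_patches = math.ceil(longer_side / patch_size)
    # Pixels covered more than once, spread evenly across neighbouring patches.
    total_overlap = num_patches * patch_size - longer_side
    overlap = total_overlap / (num_patches - 1) if num_patches > 1 else 0
    spans = []
    for i in range(num_patches):
        start = int(i * (patch_size - overlap))
        spans.append((start, start + patch_size))
    return spans


def merge_1d(patch_values: list[list[float]], spans: list[tuple[int, int]], length: int) -> list[float]:
    """Hypothetical helper: average per-patch values wherever patches overlap."""
    sums = [0.0] * length
    counts = [0] * length
    for values, (start, _end) in zip(patch_values, spans):
        for offset, value in enumerate(values):
            sums[start + offset] += value
            counts[start + offset] += 1
    # Positions covered by no patch keep a divisor of 1, like `count.clamp(min=1)`.
    return [total / max(count, 1) for total, count in zip(sums, counts)]
```

For example, `patch_spans(1000, 640)` yields two patches that share the middle 280 pixels, and `merge_1d` averages the two predictions over that shared region.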

File diff suppressed because it is too large


@@ -0,0 +1,588 @@
# coding=utf-8
# Copyright 2025 Mobile Perception Systems Lab at TU/e and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch EoMT model."""
import math
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from ...activations import ACT2FN
from ...file_utils import (
ModelOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import (
auto_docstring,
can_return_tuple,
logging,
)
from ..dinov2.modeling_dinov2 import (
Dinov2Embeddings,
Dinov2Layer,
Dinov2LayerScale,
Dinov2PatchEmbeddings,
)
from ..mask2former.modeling_mask2former import Mask2FormerForUniversalSegmentation, Mask2FormerLoss
from ..siglip.modeling_siglip import SiglipAttention
from ..vit.configuration_vit import ViTConfig
logger = logging.get_logger(__name__)
class EomtConfig(ViTConfig):
r"""
This is the configuration class to store the configuration of an [`EomtForUniversalSegmentation`]. It is used to instantiate an EoMT model
according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the EoMT
[tue-mps/coco_panoptic_eomt_large_640](https://huggingface.co/tue-mps/coco_panoptic_eomt_large_640)
architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
hidden_size (`int`, *optional*, defaults to 1024):
Dimensionality of the hidden representations.
num_hidden_layers (`int`, *optional*, defaults to 24):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads in each attention layer.
mlp_ratio (`int`, *optional*, defaults to 4):
Ratio of the MLP hidden dimensionality to the hidden size.
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the encoder.
hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
The dropout probability for all fully connected layers in the embeddings and encoder.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the layer normalization layers.
image_size (`int`, *optional*, defaults to 640):
The size (resolution) of each input image.
patch_size (`int`, *optional*, defaults to 16):
The size (resolution) of each patch.
num_channels (`int`, *optional*, defaults to 3):
The number of input channels.
layerscale_value (`float`, *optional*, defaults to 1.0):
Initial value for the LayerScale parameter.
drop_path_rate (`float`, *optional*, defaults to 0.0):
The stochastic depth rate (drop path) used during training.
num_upscale_blocks (`int`, *optional*, defaults to 2):
Number of upsampling blocks used in the decoder or segmentation head.
attention_dropout (`float`, *optional*, defaults to 0.0):
Dropout probability applied after attention projection.
use_swiglu_ffn (`bool`, *optional*, defaults to `False`):
Whether to use the SwiGLU feedforward neural network.
num_blocks (`int`, *optional*, defaults to 4):
Number of feature blocks or stages in the architecture.
no_object_weight (`float`, *optional*, defaults to 0.1):
Loss weight for the 'no object' class in panoptic/instance segmentation.
class_weight (`float`, *optional*, defaults to 2.0):
Loss weight for classification targets.
mask_weight (`float`, *optional*, defaults to 5.0):
Loss weight for mask prediction.
dice_weight (`float`, *optional*, defaults to 5.0):
Loss weight for the dice loss component.
train_num_points (`int`, *optional*, defaults to 12544):
Number of points to sample for mask loss computation during training.
oversample_ratio (`float`, *optional*, defaults to 3.0):
Oversampling ratio used in point sampling for mask training.
importance_sample_ratio (`float`, *optional*, defaults to 0.75):
Ratio of points to sample based on importance during training.
num_queries (`int`, *optional*, defaults to 200):
Number of object queries in the Transformer.
num_register_tokens (`int`, *optional*, defaults to 4):
Number of learnable register tokens added to the transformer input.
Example:
```python
>>> from transformers import EomtConfig, EomtForUniversalSegmentation
>>> # Initialize configuration
>>> config = EomtConfig()
>>> # Initialize model
>>> model = EomtForUniversalSegmentation(config)
>>> # Access config
>>> config = model.config
```"""
model_type = "eomt"
def __init__(
self,
hidden_size=1024,
num_hidden_layers=24,
num_attention_heads=16,
mlp_ratio=4,
hidden_act="gelu",
hidden_dropout_prob=0.0,
initializer_range=0.02,
layer_norm_eps=1e-6,
image_size=640,
patch_size=16,
num_channels=3,
layerscale_value=1.0,
drop_path_rate=0.0,
num_upscale_blocks=2,
attention_dropout=0.0,
use_swiglu_ffn=False,
num_blocks=4,
no_object_weight: float = 0.1,
class_weight: float = 2.0,
mask_weight: float = 5.0,
dice_weight: float = 5.0,
train_num_points: int = 12544,
oversample_ratio: float = 3.0,
importance_sample_ratio: float = 0.75,
num_queries=200,
num_register_tokens=4,
**kwargs,
):
super().__init__(
hidden_size=hidden_size,
num_hidden_layers=num_hidden_layers,
num_attention_heads=num_attention_heads,
hidden_dropout_prob=hidden_dropout_prob,
hidden_act=hidden_act,
initializer_range=initializer_range,
layer_norm_eps=layer_norm_eps,
image_size=image_size,
patch_size=patch_size,
num_channels=num_channels,
**kwargs,
)
del self.intermediate_size
del self.qkv_bias
del self.pooler_act
del self.pooler_output_size
del self.encoder_stride
del self.attention_probs_dropout_prob
self.mlp_ratio = mlp_ratio
self.attention_dropout = attention_dropout
self.layerscale_value = layerscale_value
self.drop_path_rate = drop_path_rate
self.num_upscale_blocks = num_upscale_blocks
self.use_swiglu_ffn = use_swiglu_ffn
self.num_blocks = num_blocks
self.no_object_weight = no_object_weight
self.class_weight = class_weight
self.mask_weight = mask_weight
self.dice_weight = dice_weight
self.train_num_points = train_num_points
self.oversample_ratio = oversample_ratio
self.importance_sample_ratio = importance_sample_ratio
self.num_queries = num_queries
self.num_register_tokens = num_register_tokens
@dataclass
@auto_docstring(
custom_intro="""
Class for outputs of [`EomtForUniversalSegmentation`].
This output can be directly passed to [`~EomtImageProcessor.post_process_semantic_segmentation`],
[`~EomtImageProcessor.post_process_instance_segmentation`] or
[`~EomtImageProcessor.post_process_panoptic_segmentation`] to compute final segmentation maps. Please see
[`~EomtImageProcessor`] for details regarding usage.
"""
)
class EomtForUniversalSegmentationOutput(ModelOutput):
r"""
loss (`torch.Tensor`, *optional*):
The computed loss, returned when labels are present.
class_queries_logits (`torch.FloatTensor`):
A tensor of shape `(batch_size, num_queries, num_labels + 1)` representing the proposed classes for each
query. Note the `+ 1` is needed because we incorporate the null class.
masks_queries_logits (`torch.FloatTensor`):
A tensor of shape `(batch_size, num_queries, height, width)` representing the proposed masks for each
query.
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Last hidden states (final feature map) of the last layer.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
shape `(batch_size, sequence_length, hidden_size)`. Hidden states of all layers of the model.
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `tuple(torch.FloatTensor)` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`. Self-attention weights from each transformer layer.
"""
loss: Optional[torch.FloatTensor] = None
class_queries_logits: Optional[torch.FloatTensor] = None
masks_queries_logits: Optional[torch.FloatTensor] = None
last_hidden_state: Optional[torch.FloatTensor] = None
hidden_states: Optional[tuple[torch.FloatTensor]] = None
attentions: Optional[tuple[torch.FloatTensor]] = None
class EomtLoss(Mask2FormerLoss):
pass
class EomtPatchEmbeddings(Dinov2PatchEmbeddings):
pass
class EomtEmbeddings(Dinov2Embeddings, nn.Module):
def __init__(self, config: EomtConfig) -> None:
nn.Module.__init__(self)
self.config = config
self.patch_size = config.patch_size
self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
self.register_tokens = nn.Parameter(torch.zeros(1, config.num_register_tokens, config.hidden_size))
self.patch_embeddings = EomtPatchEmbeddings(config)
num_patches = self.patch_embeddings.num_patches
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.num_prefix_tokens = 1 + config.num_register_tokens # 1 for [CLS]
self.position_embeddings = nn.Embedding(num_patches, config.hidden_size)
self.register_buffer("position_ids", torch.arange(num_patches).expand((1, -1)), persistent=False)
def interpolate_pos_encoding(self):
raise AttributeError("Not needed for Eomt Model")
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
batch_size, _, _, _ = pixel_values.shape
target_dtype = self.patch_embeddings.projection.weight.dtype
embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype))
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
register_tokens = self.register_tokens.expand(batch_size, -1, -1)
embeddings = embeddings + self.position_embeddings(self.position_ids)
embeddings = torch.cat([cls_tokens, register_tokens, embeddings], dim=1)
embeddings = self.dropout(embeddings)
return embeddings
class EomtAttention(SiglipAttention):
pass
class EomtLayerScale(Dinov2LayerScale):
pass
class EomtLayer(Dinov2Layer):
pass
class EomtLayerNorm2d(nn.LayerNorm):
def __init__(self, num_channels, eps=1e-6, affine=True):
super().__init__(num_channels, eps=eps, elementwise_affine=affine)
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
hidden_state = hidden_state.permute(0, 2, 3, 1)
hidden_state = F.layer_norm(hidden_state, self.normalized_shape, self.weight, self.bias, self.eps)
hidden_state = hidden_state.permute(0, 3, 1, 2)
return hidden_state
class EomtScaleLayer(nn.Module):
def __init__(self, config: EomtConfig):
super().__init__()
hidden_size = config.hidden_size
self.conv1 = nn.ConvTranspose2d(hidden_size, hidden_size, kernel_size=2, stride=2)
self.activation = ACT2FN[config.hidden_act]
self.conv2 = nn.Conv2d(
hidden_size,
hidden_size,
kernel_size=3,
padding=1,
groups=hidden_size,
bias=False,
)
self.layernorm2d = EomtLayerNorm2d(hidden_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.conv1(hidden_states)
hidden_states = self.activation(hidden_states)
hidden_states = self.conv2(hidden_states)
hidden_states = self.layernorm2d(hidden_states)
return hidden_states
class EomtScaleBlock(nn.Module):
def __init__(self, config: EomtConfig):
super().__init__()
self.num_blocks = config.num_upscale_blocks
self.block = nn.ModuleList([EomtScaleLayer(config) for _ in range(self.num_blocks)])
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
for block in self.block:
hidden_states = block(hidden_states)
return hidden_states
class EomtMaskHead(nn.Module):
def __init__(self, config: EomtConfig):
super().__init__()
hidden_size = config.hidden_size
self.fc1 = nn.Linear(hidden_size, hidden_size)
self.fc2 = nn.Linear(hidden_size, hidden_size)
self.fc3 = nn.Linear(hidden_size, hidden_size)
self.activation = ACT2FN[config.hidden_act]
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.activation(self.fc1(hidden_states))
hidden_states = self.activation(self.fc2(hidden_states))
hidden_states = self.fc3(hidden_states)
return hidden_states
@auto_docstring
class EomtPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = EomtConfig
base_model_prefix = "eomt"
main_input_name = "pixel_values"
supports_gradient_checkpointing = False
_no_split_modules = ["EomtMLP"]
_supports_sdpa = True
_supports_flash_attn_2 = True
def _init_weights(self, module: nn.Module) -> None:
std = self.config.initializer_range
if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
if module.bias is not None:
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
nn.init.uniform_(module.bias, -bound, bound)
elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=1)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, EomtLayerScale):
if hasattr(module, "lambda1"):
module.lambda1.data.fill_(self.config.layerscale_value)
elif isinstance(module, EomtEmbeddings):
module.cls_token.data = nn.init.trunc_normal_(
module.cls_token.data.to(torch.float32), mean=0.0, std=std
).to(module.cls_token.dtype)
module.register_tokens.data.zero_()
@auto_docstring(
custom_intro="""
The EoMT Model with a head on top for instance/semantic/panoptic segmentation.
"""
)
class EomtForUniversalSegmentation(Mask2FormerForUniversalSegmentation, nn.Module):
def __init__(self, config: EomtConfig) -> None:
nn.Module().__init__(config)
self.config = config
self.num_hidden_layers = config.num_hidden_layers
self.embeddings = EomtEmbeddings(config)
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.query = nn.Embedding(config.num_queries, config.hidden_size)
self.layers = nn.ModuleList([EomtLayer(config) for _ in range(config.num_hidden_layers)])
self.upscale_block = EomtScaleBlock(config)
self.mask_head = EomtMaskHead(config)
self.class_predictor = nn.Linear(config.hidden_size, config.num_labels + 1)
self.grid_size = (config.image_size // config.patch_size, config.image_size // config.patch_size)
self.weight_dict: dict[str, float] = {
"loss_cross_entropy": config.class_weight,
"loss_mask": config.mask_weight,
"loss_dice": config.dice_weight,
}
self.criterion = EomtLoss(config=config, weight_dict=self.weight_dict)
self.register_buffer("attn_mask_probs", torch.ones(config.num_blocks))
self.post_init()
def get_input_embeddings(self):
return self.embeddings.patch_embeddings
def get_auxiliary_logits(self):
raise AttributeError("Not needed for Eomt Model.")
def predict(self, logits: torch.Tensor):
query_tokens = logits[:, : self.config.num_queries, :]
class_logits = self.class_predictor(query_tokens)
prefix_tokens = logits[:, self.config.num_queries + self.embeddings.num_prefix_tokens :, :]
prefix_tokens = prefix_tokens.transpose(1, 2)
prefix_tokens = prefix_tokens.reshape(prefix_tokens.shape[0], -1, *self.grid_size)
query_tokens = self.mask_head(query_tokens)
prefix_tokens = self.upscale_block(prefix_tokens)
mask_logits = torch.einsum("bqc, bchw -> bqhw", query_tokens, prefix_tokens)
return mask_logits, class_logits
@staticmethod
def _disable_attention_mask(attn_mask, prob, num_query_tokens, encoder_start_tokens, device):
if prob < 1:
# Generate random queries to disable based on the probs
random_queries = torch.rand(attn_mask.shape[0], num_query_tokens, device=device) > prob
# Disable attention to the query tokens, considering the prefix tokens
attn_mask[:, :num_query_tokens, encoder_start_tokens:][random_queries] = 1
return attn_mask
@auto_docstring
@can_return_tuple
def forward(
self,
pixel_values: Tensor,
mask_labels: Optional[list[Tensor]] = None,
class_labels: Optional[list[Tensor]] = None,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
):
r"""
mask_labels (`List[torch.Tensor]`, *optional*):
List of mask labels of shape `(num_labels, height, width)` to be fed to a model.
class_labels (`List[torch.LongTensor]`, *optional*):
List of target class labels of shape `(num_labels,)` to be fed to a model. They identify the
labels of `mask_labels`, e.g. the label of `mask_labels[i][j]` is `class_labels[i][j]`.
"""
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
masks_queries_logits_per_layer, class_queries_logits_per_layer = (), ()
attention_mask = None
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
hidden_states = self.embeddings(pixel_values)
for idx, layer_module in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
if idx == self.num_hidden_layers - self.config.num_blocks:
query = self.query.weight[None, :, :].expand(hidden_states.shape[0], -1, -1)
hidden_states = torch.cat((query, hidden_states), dim=1)
if idx >= self.num_hidden_layers - self.config.num_blocks and (
self.training or self.attn_mask_probs[idx - self.num_hidden_layers + self.config.num_blocks] > 0
):
norm_hidden_states = self.layernorm(hidden_states)
masks_queries_logits, class_queries_logits = self.predict(norm_hidden_states)
masks_queries_logits_per_layer += (masks_queries_logits,)
class_queries_logits_per_layer += (class_queries_logits,)
attention_mask = torch.ones(
hidden_states.shape[0],
hidden_states.shape[1],
hidden_states.shape[1],
device=hidden_states.device,
dtype=torch.bool,
)
interpolated_logits = F.interpolate(masks_queries_logits, size=self.grid_size, mode="bilinear")
interpolated_logits = interpolated_logits.view(
interpolated_logits.size(0), interpolated_logits.size(1), -1
)
num_query_tokens = self.config.num_queries
encoder_start_tokens = num_query_tokens + self.embeddings.num_prefix_tokens
# Set attention mask for queries to focus on encoder tokens based on interpolated logits
attention_mask[:, :num_query_tokens, encoder_start_tokens:] = interpolated_logits > 0
# Disable attention mask for random query tokens.
attention_mask = self._disable_attention_mask(
attention_mask,
prob=self.attn_mask_probs[idx - self.num_hidden_layers + self.config.num_blocks],
num_query_tokens=num_query_tokens,
encoder_start_tokens=encoder_start_tokens,
device=attention_mask.device,
)
# Expand attention mask to 4d mask.
attention_mask = attention_mask[:, None, ...].expand(-1, self.config.num_attention_heads, -1, -1)
attention_mask = attention_mask.float().masked_fill(~attention_mask, -1e9)
layer_outputs = layer_module(hidden_states, attention_mask, output_attentions)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions += (layer_outputs[1],)
sequence_output = self.layernorm(hidden_states)
if output_hidden_states:
all_hidden_states += (sequence_output,)
masks_queries_logits, class_queries_logits = self.predict(sequence_output)
masks_queries_logits_per_layer += (masks_queries_logits,)
class_queries_logits_per_layer += (class_queries_logits,)
loss = None
if mask_labels is not None and class_labels is not None:
loss = 0.0
for masks_queries_logits, class_queries_logits in zip(
masks_queries_logits_per_layer, class_queries_logits_per_layer
):
loss_dict = self.get_loss_dict(
masks_queries_logits=masks_queries_logits,
class_queries_logits=class_queries_logits,
mask_labels=mask_labels,
class_labels=class_labels,
auxiliary_predictions=None,
)
loss += self.get_loss(loss_dict)
return EomtForUniversalSegmentationOutput(
loss=loss,
masks_queries_logits=masks_queries_logits,
class_queries_logits=class_queries_logits,
last_hidden_state=sequence_output,
hidden_states=all_hidden_states,
attentions=all_attentions,
)
__all__ = ["EomtConfig", "EomtPreTrainedModel", "EomtForUniversalSegmentation"]
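
The query-wise mask annealing in `_disable_attention_mask` can be illustrated without tensors. The sketch below is a hypothetical pure-Python rewrite for a single image, not code from this commit: the name `disable_attention_mask_sketch`, the list-of-lists mask, and the `rng` argument are assumptions made for illustration. Each query keeps its predicted mask with probability `prob` and otherwise regains full attention to the patch tokens.

```python
import random


def disable_attention_mask_sketch(attn_mask, prob, num_query_tokens, encoder_start_tokens, rng):
    """Hypothetical single-image, list-based analogue of `_disable_attention_mask`.

    attn_mask[q][k] is True when token q may attend to token k.
    """
    if prob < 1:
        for query in range(num_query_tokens):
            # Mirrors `torch.rand(...) > prob`: the query's mask is dropped
            # with probability 1 - prob.
            if rng.random() > prob:
                for key in range(encoder_start_tokens, len(attn_mask[query])):
                    attn_mask[query][key] = True
    return attn_mask


# Two queries, one prefix token, three patch tokens: 6 tokens in total.
num_queries, encoder_start = 2, 3
mask = [[True] * 6 for _ in range(6)]
mask[0][3:] = [True, False, False]  # query 0 is restricted to the first patch
mask[1][3:] = [False, False, True]  # query 1 is restricted to the last patch

kept = disable_attention_mask_sketch(
    [row[:] for row in mask], 1.0, num_queries, encoder_start, random.Random(0)
)
dropped = disable_attention_mask_sketch(
    [row[:] for row in mask], 0.0, num_queries, encoder_start, random.Random(0)
)
```

With `prob=1.0` the predicted masks are left untouched, which matches the `torch.ones(config.num_blocks)` initial value of the `attn_mask_probs` buffer. With `prob=0.0` and this seed, both query rows are reset to full attention over the patch tokens.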


@@ -512,7 +512,7 @@ class Mask2FormerLoss(nn.Module):
 self.importance_sample_ratio = config.importance_sample_ratio
 self.matcher = Mask2FormerHungarianMatcher(
-cost_class=1.0,
+cost_class=config.class_weight,
 cost_dice=config.dice_weight,
 cost_mask=config.mask_weight,
 num_points=self.num_points,
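
Why the `cost_class` change can matter: as far as I can tell, the matcher picks the query-to-target assignment that minimises a weighted sum of class, mask, and dice costs, so a class weight that differs from the one used in the loss can select a different assignment. The toy example below is a sketch under stated assumptions: `match_sketch` is a hypothetical brute-force stand-in for the Hungarian solver, the cost values are made up, and the weights 2.0/5.0/5.0 are illustrative rather than read from any config.

```python
from itertools import permutations


def match_sketch(class_cost, mask_cost, dice_cost, w_class, w_mask, w_dice):
    """Return a tuple `perm` where query q is assigned to target perm[q].

    Brute force over all permutations; only suitable for tiny toy inputs.
    """
    num_queries = len(class_cost)

    def total(perm):
        return sum(
            w_class * class_cost[q][t] + w_mask * mask_cost[q][t] + w_dice * dice_cost[q][t]
            for q, t in enumerate(perm)
        )

    return min(permutations(range(num_queries)), key=total)


# Made-up costs: the class term favours the identity assignment,
# while the mask and dice terms favour the swapped one.
class_cost = [[-0.9, -0.1], [-0.1, -0.9]]
mask_cost = [[0.20, 0.05], [0.05, 0.20]]
dice_cost = [[0.1, 0.0], [0.0, 0.1]]

old_assignment = match_sketch(class_cost, mask_cost, dice_cost, 1.0, 5.0, 5.0)
new_assignment = match_sketch(class_cost, mask_cost, dice_cost, 2.0, 5.0, 5.0)
```

Here `old_assignment` is `(1, 0)` and `new_assignment` is `(0, 1)`: doubling the class weight is enough to flip which target each query is matched to, and therefore which pairs the loss is computed on.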


@@ -0,0 +1,308 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch EoMT Image Processor."""
import unittest
import numpy as np
import requests
from datasets import load_dataset
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
if is_torch_available():
import torch
if is_vision_available():
from PIL import Image
from transformers import EomtImageProcessor
if is_torchvision_available():
from transformers import EomtImageProcessorFast
from transformers.models.eomt.modeling_eomt import EomtForUniversalSegmentationOutput
class EomtImageProcessingTester:
def __init__(
self,
parent,
batch_size=7,
num_channels=3,
min_resolution=30,
max_resolution=400,
size=None,
do_resize=True,
do_pad=True,
do_normalize=True,
image_mean=[0.5, 0.5, 0.5],
image_std=[0.5, 0.5, 0.5],
num_labels=10,
):
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
self.min_resolution = min_resolution
self.max_resolution = max_resolution
self.do_resize = do_resize
self.do_pad = do_pad
self.size = size if size is not None else {"shortest_edge": 18, "longest_edge": 18}
self.do_normalize = do_normalize
self.image_mean = image_mean
self.image_std = image_std
# for the post_process_functions
self.batch_size = 2
self.num_queries = 3
self.num_classes = 2
self.height = 18
self.width = 18
self.num_labels = num_labels
def prepare_image_processor_dict(self):
return {
"do_resize": self.do_resize,
"size": self.size,
"do_normalize": self.do_normalize,
"image_mean": self.image_mean,
"image_std": self.image_std,
"do_pad": self.do_pad,
"num_labels": self.num_labels,
}
def prepare_fake_eomt_outputs(self, batch_size):
return EomtForUniversalSegmentationOutput(
masks_queries_logits=torch.randn((batch_size, self.num_queries, self.height, self.width)),
class_queries_logits=torch.randn((batch_size, self.num_queries, self.num_classes + 1)),
)
def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
return prepare_image_inputs(
batch_size=self.batch_size,
num_channels=self.num_channels,
min_resolution=self.min_resolution,
max_resolution=self.max_resolution,
equal_resolution=equal_resolution,
numpify=numpify,
torchify=torchify,
)
def prepare_semantic_single_inputs():
ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
example = ds[0]
return example["image"], example["map"]
def prepare_semantic_batch_inputs():
ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
return list(ds["image"][:2]), list(ds["map"][:2])
@require_torch
@require_vision
class EomtImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = EomtImageProcessor if is_vision_available() else None
fast_image_processing_class = EomtImageProcessorFast if is_torchvision_available() else None
def setUp(self):
super().setUp()
self.image_processor_tester = EomtImageProcessingTester(self)
self.model_id = "tue-mps/coco_panoptic_eomt_large_640"
@property
def image_processor_dict(self):
return self.image_processor_tester.prepare_image_processor_dict()
def test_image_processor_properties(self):
for image_processing_class in self.image_processor_list:
image_processing = image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "image_mean"))
self.assertTrue(hasattr(image_processing, "image_std"))
self.assertTrue(hasattr(image_processing, "do_normalize"))
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "size"))
self.assertTrue(hasattr(image_processing, "do_rescale"))
self.assertTrue(hasattr(image_processing, "rescale_factor"))
self.assertTrue(hasattr(image_processing, "resample"))
def test_image_processor_from_dict_with_kwargs(self):
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
self.assertEqual(image_processor.size, {"shortest_edge": 18, "longest_edge": 18})
image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42)
self.assertEqual(image_processor.size, {"shortest_edge": 42})
def test_call_numpy(self):
for image_processing_class in self.image_processor_list:
# Initialize image_processing
image_processing = image_processing_class(**self.image_processor_dict)
# create random numpy tensors
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, numpify=True)
for image in image_inputs:
self.assertIsInstance(image, np.ndarray)
# Test not batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
expected_output_image_shape = (1, 3, 18, 18)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
# Test batched
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
expected_output_image_shape = (2, 3, 18, 18)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
@unittest.skip(reason="Not supported")
def test_call_numpy_4_channels(self):
pass
def test_call_pil(self):
image_processing = self.image_processing_class(**self.image_processor_dict)
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True)
for image in image_inputs:
self.assertIsInstance(image, Image.Image)
# Test Non batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
expected_output_image_shape = (1, 3, 18, 18)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
# Test batched
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
expected_output_image_shape = (2, 3, 18, 18)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
def test_call_pytorch(self):
image_processing = self.image_processing_class(**self.image_processor_dict)
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True)
for image in image_inputs:
self.assertIsInstance(image, torch.Tensor)
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
expected_output_image_shape = (1, 3, 18, 18)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
expected_output_image_shape = (2, 3, 18, 18)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
def test_slow_fast_equivalence(self):
if not self.test_slow_image_processor or not self.test_fast_image_processor:
self.skipTest(reason="Skipping slow/fast equivalence test")
if self.image_processing_class is None or self.fast_image_processing_class is None:
self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")
dummy_image, dummy_map = prepare_semantic_single_inputs()
image_processor_slow = self.image_processing_class(**self.image_processor_dict)
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
image_encoding_slow = image_processor_slow(dummy_image, segmentation_maps=dummy_map, return_tensors="pt")
image_encoding_fast = image_processor_fast(dummy_image, segmentation_maps=dummy_map, return_tensors="pt")
self.assertTrue(torch.allclose(image_encoding_slow.pixel_values, image_encoding_fast.pixel_values, atol=1e-1))
self.assertLessEqual(
torch.mean(torch.abs(image_encoding_slow.pixel_values - image_encoding_fast.pixel_values)).item(), 1e-3
)
# Check that at least 99.9% of the mask_labels values match.
match_ratio = (image_encoding_slow.mask_labels[0] == image_encoding_fast.mask_labels[0]).float().mean().item()
self.assertGreaterEqual(match_ratio, 0.999, "Mask labels do not match between slow and fast image processor.")
def test_slow_fast_equivalence_batched(self):
if not self.test_slow_image_processor or not self.test_fast_image_processor:
self.skipTest(reason="Skipping slow/fast equivalence test")
if self.image_processing_class is None or self.fast_image_processing_class is None:
self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")
if hasattr(self.image_processor_tester, "do_center_crop") and self.image_processor_tester.do_center_crop:
self.skipTest(
reason="Skipping as do_center_crop is True and center_crop functions are not equivalent for fast and slow processors"
)
dummy_images, dummy_maps = prepare_semantic_batch_inputs()
image_processor_slow = self.image_processing_class(**self.image_processor_dict)
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
encoding_slow = image_processor_slow(dummy_images, segmentation_maps=dummy_maps, return_tensors="pt")
encoding_fast = image_processor_fast(dummy_images, segmentation_maps=dummy_maps, return_tensors="pt")
self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1))
self.assertLessEqual(
torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3
)
for idx in range(len(dummy_maps)):
match_ratio = (encoding_slow.mask_labels[idx] == encoding_fast.mask_labels[idx]).float().mean().item()
self.assertGreaterEqual(
match_ratio, 0.999, "Mask labels do not match between slow and fast image processors."
)
def test_post_process_semantic_segmentation(self):
processor = self.image_processing_class(**self.image_processor_dict)
# Set longest_edge to None to test semantic segmentation.
processor.size = {"shortest_edge": 18, "longest_edge": None}
image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
inputs = processor(images=image, do_split_image=True, return_tensors="pt")
patch_offsets = inputs.pop("patch_offsets")
original_sizes = [image.size[::-1]]
# For semantic segmentation, the output batch size is 2 because two patches are created for the image.
outputs = self.image_processor_tester.prepare_fake_eomt_outputs(inputs["pixel_values"].shape[0])
segmentation = processor.post_process_semantic_segmentation(outputs, patch_offsets, original_sizes)
self.assertEqual(segmentation[0].shape, (image.height, image.width))
def test_post_process_panoptic_segmentation(self):
processor = self.image_processing_class(**self.image_processor_dict)
image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
original_sizes = [image.size[::-1], image.size[::-1]]
# Test a batched input of 2.
outputs = self.image_processor_tester.prepare_fake_eomt_outputs(2)
segmentation = processor.post_process_panoptic_segmentation(outputs, original_sizes)
self.assertTrue(len(segmentation) == 2)
for el in segmentation:
self.assertTrue("segmentation" in el)
self.assertTrue("segments_info" in el)
self.assertEqual(type(el["segments_info"]), list)
self.assertEqual(el["segmentation"].shape, (image.height, image.width))
def test_post_process_instance_segmentation(self):
processor = self.image_processing_class(**self.image_processor_dict)
image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
original_sizes = [image.size[::-1], image.size[::-1]]
# Test a batched input of 2.
outputs = self.image_processor_tester.prepare_fake_eomt_outputs(2)
segmentation = processor.post_process_instance_segmentation(outputs, original_sizes)
self.assertTrue(len(segmentation) == 2)
for el in segmentation:
self.assertTrue("segmentation" in el)
self.assertTrue("segments_info" in el)
self.assertEqual(type(el["segments_info"]), list)
self.assertEqual(el["segmentation"].shape, (image.height, image.width))


@@ -0,0 +1,475 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch EoMT model."""
import unittest
import requests
from transformers import AutoImageProcessor, EomtConfig, EomtForUniversalSegmentation
from transformers.testing_utils import require_torch, require_torch_accelerator, require_torch_fp16, slow, torch_device
from transformers.utils import is_torch_available, is_vision_available
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor
if is_torch_available():
import torch
if is_vision_available():
from PIL import Image
class EomtForUniversalSegmentationTester:
def __init__(
self,
parent,
batch_size=2,
is_training=True,
image_size=40,
patch_size=2,
num_queries=5,
num_register_tokens=19,
num_labels=4,
hidden_size=8,
num_attention_heads=2,
num_hidden_layers=4,
):
self.parent = parent
self.batch_size = batch_size
self.is_training = is_training
self.num_queries = num_queries
self.image_size = image_size
self.patch_size = patch_size
self.num_labels = num_labels
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
self.num_hidden_layers = num_hidden_layers
self.num_register_tokens = num_register_tokens
num_patches = (image_size // patch_size) ** 2
self.seq_length = num_patches + 1
def get_config(self):
config = {
"image_size": self.image_size,
"patch_size": self.patch_size,
"num_labels": self.num_labels,
"hidden_size": self.hidden_size,
"num_attention_heads": self.num_attention_heads,
"num_hidden_layers": self.num_hidden_layers,
"num_register_tokens": self.num_register_tokens,
"num_queries": self.num_queries,
"num_blocks": 1,
}
return EomtConfig(**config)
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, 3, self.image_size, self.image_size]).to(torch_device)
mask_labels = (
torch.rand([self.batch_size, self.num_labels, self.image_size, self.image_size], device=torch_device) > 0.5
).float()
class_labels = (torch.rand((self.batch_size, self.num_labels), device=torch_device) > 0.5).long()
config = self.get_config()
return config, pixel_values, mask_labels, class_labels
def prepare_config_and_inputs_for_common(self):
config, pixel_values, mask_labels, class_labels = self.prepare_config_and_inputs()
inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict
def prepare_config_and_inputs_for_training(self):
config, pixel_values, mask_labels, class_labels = self.prepare_config_and_inputs()
inputs_dict = {"pixel_values": pixel_values, "mask_labels": mask_labels, "class_labels": class_labels}
return config, inputs_dict
@require_torch
class EomtForUniversalSegmentationTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (EomtForUniversalSegmentation,) if is_torch_available() else ()
is_encoder_decoder = False
test_pruning = False
test_head_masking = False
test_missing_keys = False
test_torch_exportable = False
def setUp(self):
self.model_tester = EomtForUniversalSegmentationTester(self)
self.config_tester = ConfigTester(self, config_class=EomtConfig, has_text_modality=False)
def test_config(self):
self.config_tester.run_common_tests()
def test_model_with_labels(self):
size = (self.model_tester.image_size,) * 2
inputs = {
"pixel_values": torch.randn((2, 3, *size), device=torch_device),
"mask_labels": torch.randn((2, 10, *size), device=torch_device),
"class_labels": torch.zeros(2, 10, device=torch_device).long(),
}
config = self.model_tester.get_config()
model = EomtForUniversalSegmentation(config).to(torch_device)
outputs = model(**inputs)
self.assertTrue(outputs.loss is not None)
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
for model_class in self.all_model_classes:
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = False
config.return_dict = True
model = model_class._from_config(config, attn_implementation="eager")
config = model.config
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
# Check that output_attentions also works when set through the config
del inputs_dict["output_attentions"]
config.output_attentions = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
out_len = len(outputs)
# Check attention is always last and order is fine
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
added_hidden_states = 1
self.assertEqual(out_len + added_hidden_states, len(outputs))
self_attentions = outputs.attentions
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class):
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
expected_num_layers = getattr(
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
)
self.assertEqual(len(hidden_states), expected_num_layers)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
inputs_dict["output_hidden_states"] = True
check_hidden_states_output(inputs_dict, config, model_class)
# Check that output_hidden_states also works when set through the config
del inputs_dict["output_hidden_states"]
config.output_hidden_states = True
check_hidden_states_output(inputs_dict, config, model_class)
@unittest.skip(reason="EoMT does not use inputs_embeds")
def test_inputs_embeds(self):
pass
@unittest.skip(reason="EoMT does not have a get_input_embeddings method")
def test_model_get_set_embeddings(self):
pass
@unittest.skip(reason="EoMT is not a generative model")
def test_generate_without_input_ids(self):
pass
@unittest.skip(reason="EoMT does not use token embeddings")
def test_resize_tokens_embeddings(self):
pass
def test_training(self):
if not self.model_tester.is_training:
self.skipTest(reason="ModelTester is not configured to run training tests")
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_training()
config.return_dict = True
model = model_class(config)
model.to(torch_device)
model.train()
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
loss = model(**inputs).loss
loss.backward()
def test_initialization(self):
# Apart from the below params, all other parameters are initialized using kaiming uniform.
non_uniform_init_parms = [
"layernorm.bias",
"layernorm.weight",
"norm1.bias",
"norm1.weight",
"norm2.bias",
"norm2.weight",
"layer_scale1.lambda1",
"layer_scale2.lambda1",
"register_tokens",
"cls_token",
]
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
configs_no_init = _config_zero_init(config)
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
for name, param in model.named_parameters():
if param.requires_grad:
if any(x in name for x in non_uniform_init_parms):
self.assertIn(
((param.data.mean() * 1e9).round() / 1e9).item(),
[0.0, 1.0],
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
else:
self.assertTrue(
-1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
@require_torch
class EomtForUniversalSegmentationIntegrationTest(unittest.TestCase):
def setUp(self):
self.model_id = "tue-mps/coco_panoptic_eomt_large_640"

    @slow
    def test_inference(self):
        model = EomtForUniversalSegmentation.from_pretrained(self.model_id, device_map="auto")
        processor = AutoImageProcessor.from_pretrained(self.model_id)

        image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)

        inputs = processor(images=image, return_tensors="pt").to(model.device)

        with torch.inference_mode():
            outputs = model(**inputs)

        self.assertTrue(outputs.class_queries_logits.shape == (1, 200, 134))
        self.assertTrue(outputs.masks_queries_logits.shape == (1, 200, 160, 160))

        # fmt: off
        EXPECTED_SLICE = torch.tensor([
            [ 13.2540, 8.9279, 8.6631, 12.3760, 10.1429],
            [ -3.4815, -36.4630, -45.5604, -46.8404, -37.5099],
            [ -6.8689, -44.4206, -62.7591, -59.2928, -47.7035],
            [ -2.9380, -42.0659, -57.4382, -55.1537, -43.5142],
            [ -8.4387, -38.5275, -53.1383, -47.0064, -38.9667],
        ]).to(model.device)
        # fmt: on

        output_slice = outputs.masks_queries_logits[0, 0, :5, :5]
        torch.testing.assert_close(output_slice, EXPECTED_SLICE, rtol=1e-2, atol=1e-2)

        # fmt: off
        EXPECTED_SLICE = torch.tensor([
            [-0.6977, -6.4907, -4.1178, -6.5554, -6.6529],
            [-0.3650, -6.6560, -4.0143, -6.5776, -6.5879],
            [-0.8820, -6.7175, -3.5334, -6.8569, -6.2415],
            [ 0.4502, -5.3911, -3.0232, -5.9411, -6.3243],
            [ 0.3157, -5.6321, -2.6716, -5.5740, -5.5607],
        ]).to(model.device)
        # fmt: on

        output_slice = outputs.class_queries_logits[0, :5, :5]
        torch.testing.assert_close(output_slice, EXPECTED_SLICE, rtol=1e-2, atol=1e-2)

    @require_torch_accelerator
    @require_torch_fp16
    @slow
    def test_inference_fp16(self):
        model = EomtForUniversalSegmentation.from_pretrained(
            self.model_id, torch_dtype=torch.float16, device_map="auto"
        )
        processor = AutoImageProcessor.from_pretrained(self.model_id)

        image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)

        inputs = processor(images=image, return_tensors="pt").to(model.device)

        with torch.inference_mode():
            outputs = model(**inputs)

        self.assertTrue(outputs.class_queries_logits.shape == (1, 200, 134))
        self.assertTrue(outputs.masks_queries_logits.shape == (1, 200, 160, 160))

    @slow
    def test_semantic_segmentation_inference(self):
        model_id = "tue-mps/ade20k_semantic_eomt_large_512"
        model = EomtForUniversalSegmentation.from_pretrained(model_id, device_map="auto")
        processor = AutoImageProcessor.from_pretrained(model_id)

        image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)

        inputs = processor(images=image, return_tensors="pt").to(model.device)
        # Set the patch offsets aside; they are passed to the post-processing step below.
        patch_offsets = inputs.pop("patch_offsets", None)

        with torch.inference_mode():
            outputs = model(**inputs)

        self.assertTrue(outputs.class_queries_logits.shape == (2, 100, 151))
        self.assertTrue(outputs.masks_queries_logits.shape == (2, 100, 128, 128))

        preds = processor.post_process_semantic_segmentation(
            outputs, original_image_sizes=[(image.size[1], image.size[0])], patch_offsets=patch_offsets
        )

        # PIL reports size as (width, height); the prediction map is (height, width).
        self.assertTrue(preds.shape[1:] == (image.size[1], image.size[0]))

        # fmt: off
        EXPECTED_SLICE = torch.tensor([
            [39, 39, 39, 39, 39, 39, 39, 39, 39, 39],
            [39, 39, 39, 39, 39, 39, 39, 39, 39, 39],
            [39, 39, 39, 39, 39, 39, 39, 39, 39, 39],
            [39, 39, 39, 39, 39, 39, 39, 39, 39, 39],
            [39, 39, 39, 39, 39, 39, 39, 39, 39, 39],
            [39, 39, 39, 39, 39, 39, 39, 39, 39, 39],
            [39, 39, 39, 39, 39, 39, 39, 39, 39, 39],
            [39, 39, 39, 39, 39, 39, 39, 39, 39, 39],
            [39, 39, 39, 39, 39, 39, 39, 39, 39, 39],
            [39, 39, 39, 39, 39, 39, 39, 39, 39, 39]
        ], device=model.device)
        # fmt: on

        output_slice = preds[0, :10, :10]
        torch.testing.assert_close(output_slice, EXPECTED_SLICE, rtol=1e-2, atol=1e-2)

    @slow
    def test_panoptic_segmentation_inference(self):
        model = EomtForUniversalSegmentation.from_pretrained(self.model_id, device_map="auto")
        processor = AutoImageProcessor.from_pretrained(self.model_id)

        image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)

        inputs = processor(images=image, return_tensors="pt").to(model.device)

        with torch.inference_mode():
            outputs = model(**inputs)

        self.assertTrue(outputs.class_queries_logits.shape == (1, 200, 134))
        self.assertTrue(outputs.masks_queries_logits.shape == (1, 200, 160, 160))

        preds = processor.post_process_panoptic_segmentation(
            outputs, original_image_sizes=[(image.size[1], image.size[0])]
        )[0]
        segmentation, segments_info = preds["segmentation"], preds["segments_info"]

        # fmt: off
        EXPECTED_SLICE = torch.tensor([
            [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
            [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
            [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
            [-1, -1, -1, -1, -1, 2, 2, 2, 2, 2],
            [-1, -1, -1, 2, 2, 2, 2, 2, 2, 2],
            [ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
            [ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
            [ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
            [ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
            [ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
        ], device=model.device)

        EXPECTED_SEGMENTS_INFO = [
            {"id": 0, "label_id": 15, "score": 0.99935},
            {"id": 1, "label_id": 15, "score": 0.998688},
            {"id": 2, "label_id": 57, "score": 0.954325},
            {"id": 3, "label_id": 65, "score": 0.997285},
            {"id": 4, "label_id": 65, "score": 0.99711}
        ]
        # fmt: on

        output_slice = segmentation[:10, :10]
        torch.testing.assert_close(output_slice, EXPECTED_SLICE, rtol=1e-2, atol=1e-2)
        for actual, expected in zip(segments_info, EXPECTED_SEGMENTS_INFO):
            self.assertEqual(actual["id"], expected["id"])
            self.assertEqual(actual["label_id"], expected["label_id"])
            self.assertAlmostEqual(actual["score"], expected["score"], delta=1e-3)

    @slow
    def test_instance_segmentation_inference(self):
        model_id = "tue-mps/coco_instance_eomt_large_640"
        model = EomtForUniversalSegmentation.from_pretrained(model_id, device_map="auto")
        processor = AutoImageProcessor.from_pretrained(model_id)

        image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)

        inputs = processor(images=image, return_tensors="pt").to(model.device)

        with torch.inference_mode():
            outputs = model(**inputs)

        self.assertTrue(outputs.class_queries_logits.shape == (1, 200, 81))
        self.assertTrue(outputs.masks_queries_logits.shape == (1, 200, 160, 160))

        preds = processor.post_process_instance_segmentation(
            outputs, original_image_sizes=[(image.size[1], image.size[0])]
        )[0]
        segmentation, segments_info = preds["segmentation"], preds["segments_info"]

        # fmt: off
        EXPECTED_SLICE = torch.tensor([
            [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
            [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
            [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
            [-1., -1., -1., 0., 0., 1., 1., 1., 1., 1.],
            [ 0., 0., 1., 1., 1., 1., 1., 1., 1., 1.],
            [ 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
            [ 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
            [ 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
            [ 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
            [ 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]
        ], device=model.device)

        EXPECTED_SEGMENTS_INFO = [
            {'id': 0, 'label_id': 57, 'score': 0.871247},
            {'id': 1, 'label_id': 57, 'score': 0.821225},
            {'id': 2, 'label_id': 15, 'score': 0.976252},
            {'id': 3, 'label_id': 65, 'score': 0.972960},
            {'id': 4, 'label_id': 65, 'score': 0.981109},
            {'id': 5, 'label_id': 15, 'score': 0.972689}
        ]
        # fmt: on

        output_slice = segmentation[:10, :10]
        torch.testing.assert_close(output_slice, EXPECTED_SLICE, rtol=1e-2, atol=1e-2)
        for actual, expected in zip(segments_info, EXPECTED_SEGMENTS_INFO):
            self.assertEqual(actual["id"], expected["id"])
            self.assertEqual(actual["label_id"], expected["label_id"])
            self.assertAlmostEqual(actual["score"], expected["score"], delta=1e-3)