✨ Add EoMT Model || 🚨 Fix Mask2Former loss calculation (#37610)
* Initial Commit * up * More changes * up * Only mask_logits mismatch * close enough logits debug later * fixes * format * Add dummy loss * Close enough processing for semantic seg * nit * Added panoptic postprocessor * refactor * refactor * finally fixed panoptic postprocessor * temp update * Refactor ForUniversalSegmentation class * nits and config update * Few fixes and inference matches * change mapping * Added training support but loss slightly off 🥲 * Loss is matching 😀 * update * Initial tests skelton * changes * tests update * more modular * initial tests * updates * better docstrings * changes * proc tests passing :) * Image processor update * tiny change * QOL changes * Update test w.r.t latest attn refactor * repo-consistency fixes * up * Image proc fix and integration tests :) * docs update * integration tests * fix * docs update 🥰 * minor fix * Happy CI * fix * obvious refactoring * refactoring w.r.t review * Add fask image proc skelton * Fast Image proc and cleanups * Use more modular * tests update * Add more tests * Nit * QOL updates * change init_weights to torch default * add eager func coz of make style * up * changes * typo fix * Updates * More deterministic tests * More modular * go more modular 🚀 * up * dump * add supprot for giant ckpts * overhaul * modular * refactor * instace seg is ready * cleanup * forgot this * docs cleanup * minor changes * EoMT - > Eomt * Happy CI * remove redundant comment * Change model references * final change * check annealing per block * My other PR changes 😂 --------- Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>
This commit is contained in:
@@ -737,6 +737,8 @@
|
||||
title: EfficientFormer
|
||||
- local: model_doc/efficientnet
|
||||
title: EfficientNet
|
||||
- local: model_doc/eomt
|
||||
title: EoMT
|
||||
- local: model_doc/focalnet
|
||||
title: FocalNet
|
||||
- local: model_doc/glpn
|
||||
|
||||
214
docs/source/en/model_doc/eomt.md
Normal file
214
docs/source/en/model_doc/eomt.md
Normal file
@@ -0,0 +1,214 @@
|
||||
<!--Copyright 2025 Mobile Perception Systems Lab at TU/e and The HuggingFace Inc. team. All rights reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
-->
|
||||
|
||||
# EoMT
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
|
||||
The Encoder-only Mask Transformer (EoMT) model was introduced in the CVPR 2025 Highlight Paper [Your ViT is Secretly an Image Segmentation Model](https://www.tue-mps.org/eomt) by Tommie Kerssies, Niccolò Cavagnero, Alexander Hermans, Narges Norouzi, Giuseppe Averta, Bastian Leibe, Gijs Dubbelman, and Daan de Geus.
|
||||
EoMT reveals Vision Transformers can perform image segmentation efficiently without task-specific components.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*Vision Transformers (ViTs) have shown remarkable performance and scalability across various computer vision tasks. To apply single-scale ViTs to image segmentation, existing methods adopt a convolutional adapter to generate multi-scale features, a pixel decoder to fuse these features, and a Transformer decoder that uses the fused features to make predictions. In this paper, we show that the inductive biases introduced by these task-specific components can instead be learned by the ViT itself, given sufficiently large models and extensive pre-training. Based on these findings, we introduce the Encoder-only Mask Transformer (EoMT), which repurposes the plain ViT architecture to conduct image segmentation. With large-scale models and pre-training, EoMT obtains a segmentation accuracy similar to state-of-the-art models that use task-specific components. At the same time, EoMT is significantly faster than these methods due to its architectural simplicity, e.g., up to 4x faster with ViT-L. Across a range of model sizes, EoMT demonstrates an optimal balance between segmentation accuracy and prediction speed, suggesting that compute resources are better spent on scaling the ViT itself rather than adding architectural complexity.*
|
||||
|
||||
This model was contributed by [Yaswanth Gali](https://huggingface.co/yaswanthgali).
|
||||
The original code can be found [here](https://github.com/tue-mps/eomt).
|
||||
|
||||
## Architecture Info
|
||||
|
||||
The `EoMT` model uses a DINOv2-pretrained Vision Transformer with **register tokens** as its backbone. EoMT simplifies the segmentation pipeline by relying solely on the encoder, eliminating the need for task-specific decoders commonly used in prior approaches.
|
||||
|
||||
Architecturally, EoMT introduces a small set of **learned queries** and a lightweight **mask prediction module**. These queries are injected into the final encoder blocks, enabling **joint attention** between image patches and object queries. During training, **masked attention** is applied to constrain each query to focus on its corresponding region—effectively mimicking cross-attention. This constraint is gradually phased out via a **mask annealing strategy**, allowing for **efficient, decoder-free inference** without compromising segmentation performance.
|
||||
|
||||
<div style="text-align: center;">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/eomt_architecture.png"
|
||||
alt="drawing" width="500"/>
|
||||
</div>
|
||||
|
||||
|
||||
The model supports semantic, instance, and panoptic segmentation using a unified architecture and task-specific post-processing.
|
||||
|
||||
## Usage Examples
|
||||
|
||||
Use the Hugging Face implementation of EoMT for inference with pre-trained models.
|
||||
|
||||
### Semantic Segmentation
|
||||
|
||||
The EoMT model performs semantic segmentation using sliding-window inference. The input image is resized such that the shorter side matches the target input size, then it is split into overlapping crops. Each crop is then passed through the model. After inference, the predicted logits from each crop are stitched back together and rescaled to the original image size to get the final segmentation mask.
|
||||
|
||||
> **Note:**
|
||||
> If you want to use a custom target size for **semantic segmentation**, specify it in the following format:
|
||||
> `{"shortest_edge": 512}`
|
||||
> Notice that `longest_edge` is not provided here — this is intentional. For semantic segmentation, images are typically **scaled so that the shortest edge is greater than or equal to the target size** hence longest_edge is not necessary.
|
||||
|
||||
```python
|
||||
import matplotlib.pyplot as plt
|
||||
import requests
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from transformers import EomtForUniversalSegmentation, AutoImageProcessor
|
||||
|
||||
|
||||
model_id = "tue-mps/ade20k_semantic_eomt_large_512"
|
||||
processor = AutoImageProcessor.from_pretrained(model_id)
|
||||
model = EomtForUniversalSegmentation.from_pretrained(model_id)
|
||||
|
||||
image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
|
||||
|
||||
inputs = processor(
|
||||
images=image,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
# Remove Patch Offsets from inputs — only used later for post-processing.
|
||||
patch_offsets = inputs.pop("patch_offsets")
|
||||
|
||||
with torch.inference_mode():
|
||||
outputs = model(**inputs)
|
||||
|
||||
# Prepare the original image size in the format (height, width)
|
||||
original_image_sizes = [(image.height, image.width)]
|
||||
|
||||
# Post-process the model outputs to get final segmentation prediction
|
||||
preds = processor.post_process_semantic_segmentation(
|
||||
outputs,
|
||||
patch_offsets=patch_offsets,
|
||||
original_image_sizes=original_image_sizes,
|
||||
)
|
||||
|
||||
# Visualize the segmentation mask
|
||||
plt.imshow(preds[0])
|
||||
plt.axis("off")
|
||||
plt.title("Semantic Segmentation")
|
||||
plt.show()
|
||||
```
|
||||
|
||||
### Instance Segmentation
|
||||
|
||||
The EoMT model performs instance segmentation using padded inference. The input image is resized so that the longer side matches the target input size, and the shorter side is zero-padded to form a square. The resulting mask and class logits are combined through post-processing (adapted from Mask2Former) to produce a unified instance segmentation map, along with segment metadata like segment id, class labels and confidence scores.
|
||||
|
||||
> **Note:**
|
||||
> To use a custom target size, specify the size as a dictionary in the following format:
|
||||
> `{"shortest_edge": 512, "longest_edge": 512}`
|
||||
> For both instance and panoptic segmentation, input images will be **scaled and padded** to this target size.
|
||||
|
||||
```python
|
||||
import matplotlib.pyplot as plt
|
||||
import requests
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from transformers import EomtForUniversalSegmentation, AutoImageProcessor
|
||||
|
||||
|
||||
model_id = "tue-mps/coco_instance_eomt_large_640"
|
||||
processor = AutoImageProcessor.from_pretrained(model_id)
|
||||
model = EomtForUniversalSegmentation.from_pretrained(model_id)
|
||||
|
||||
image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
|
||||
|
||||
inputs = processor(
|
||||
images=image,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
with torch.inference_mode():
|
||||
outputs = model(**inputs)
|
||||
|
||||
# Prepare the original image size in the format (height, width)
|
||||
original_image_sizes = [(image.height, image.width)]
|
||||
|
||||
# Post-process the model outputs to get final segmentation prediction
|
||||
preds = processor.post_process_instance_segmentation(
|
||||
outputs,
|
||||
original_image_sizes=original_image_sizes,
|
||||
)
|
||||
|
||||
# Visualize the segmentation mask
|
||||
plt.imshow(preds[0]["segmentation"])
|
||||
plt.axis("off")
|
||||
plt.title("Instance Segmentation")
|
||||
plt.show()
|
||||
```
|
||||
|
||||
### Panoptic Segmentation
|
||||
|
||||
The EoMT model performs panoptic segmentation using the same padded inference strategy as in instance segmentation. After padding and normalization, the model predicts both thing (instances) and stuff (amorphous regions) classes. The resulting mask and class logits are combined through post-processing (adapted from Mask2Former) to produce a unified panoptic segmentation map, along with segment metadata like segment id, class labels and confidence scores.
|
||||
|
||||
```python
|
||||
import matplotlib.pyplot as plt
|
||||
import requests
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from transformers import EomtForUniversalSegmentation, AutoImageProcessor
|
||||
|
||||
|
||||
model_id = "tue-mps/coco_panoptic_eomt_large_640"
|
||||
processor = AutoImageProcessor.from_pretrained(model_id)
|
||||
model = EomtForUniversalSegmentation.from_pretrained(model_id)
|
||||
|
||||
image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
|
||||
|
||||
inputs = processor(
|
||||
images=image,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
with torch.inference_mode():
|
||||
outputs = model(**inputs)
|
||||
|
||||
# Prepare the original image size in the format (height, width)
|
||||
original_image_sizes = [(image.height, image.width)]
|
||||
|
||||
# Post-process the model outputs to get final segmentation prediction
|
||||
preds = processor.post_process_panoptic_segmentation(
|
||||
outputs,
|
||||
original_image_sizes=original_image_sizes,
|
||||
)
|
||||
|
||||
# Visualize the panoptic segmentation mask
|
||||
plt.imshow(preds[0]["segmentation"])
|
||||
plt.axis("off")
|
||||
plt.title("Panoptic Segmentation")
|
||||
plt.show()
|
||||
```
|
||||
|
||||
## EomtImageProcessor
|
||||
|
||||
[[autodoc]] EomtImageProcessor
|
||||
- preprocess
|
||||
- post_process_semantic_segmentation
|
||||
- post_process_instance_segmentation
|
||||
- post_process_panoptic_segmentation
|
||||
|
||||
## EomtImageProcessorFast
|
||||
|
||||
[[autodoc]] EomtImageProcessorFast
|
||||
- preprocess
|
||||
- post_process_semantic_segmentation
|
||||
- post_process_instance_segmentation
|
||||
- post_process_panoptic_segmentation
|
||||
|
||||
## EomtConfig
|
||||
|
||||
[[autodoc]] EomtConfig
|
||||
|
||||
## EomtForUniversalSegmentation
|
||||
|
||||
[[autodoc]] EomtForUniversalSegmentation
|
||||
- forward
|
||||
@@ -122,6 +122,7 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
|
||||
("emu3", "Emu3Config"),
|
||||
("encodec", "EncodecConfig"),
|
||||
("encoder-decoder", "EncoderDecoderConfig"),
|
||||
("eomt", "EomtConfig"),
|
||||
("ernie", "ErnieConfig"),
|
||||
("ernie_m", "ErnieMConfig"),
|
||||
("esm", "EsmConfig"),
|
||||
@@ -501,6 +502,7 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str](
|
||||
("emu3", "Emu3"),
|
||||
("encodec", "EnCodec"),
|
||||
("encoder-decoder", "Encoder decoder"),
|
||||
("eomt", "EoMT"),
|
||||
("ernie", "ERNIE"),
|
||||
("ernie_m", "ErnieM"),
|
||||
("esm", "ESM"),
|
||||
|
||||
@@ -84,6 +84,7 @@ else:
|
||||
("dpt", ("DPTImageProcessor", "DPTImageProcessorFast")),
|
||||
("efficientformer", ("EfficientFormerImageProcessor",)),
|
||||
("efficientnet", ("EfficientNetImageProcessor", "EfficientNetImageProcessorFast")),
|
||||
("eomt", ("EomtImageProcessor", "EomtImageProcessorFast")),
|
||||
("flava", ("FlavaImageProcessor", "FlavaImageProcessorFast")),
|
||||
("focalnet", ("BitImageProcessor", "BitImageProcessorFast")),
|
||||
("fuyu", ("FuyuImageProcessor",)),
|
||||
|
||||
@@ -854,6 +854,7 @@ MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
# Model for Universal Segmentation mapping
|
||||
("detr", "DetrForSegmentation"),
|
||||
("eomt", "EomtForUniversalSegmentation"),
|
||||
("mask2former", "Mask2FormerForUniversalSegmentation"),
|
||||
("maskformer", "MaskFormerForInstanceSegmentation"),
|
||||
("oneformer", "OneFormerForUniversalSegmentation"),
|
||||
|
||||
29
src/transformers/models/eomt/__init__.py
Normal file
29
src/transformers/models/eomt/__init__.py
Normal file
@@ -0,0 +1,29 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import _LazyModule
|
||||
from ...utils.import_utils import define_import_structure
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_eomt import *
|
||||
from .image_processing_eomt import *
|
||||
from .image_processing_eomt_fast import *
|
||||
from .modeling_eomt import *
|
||||
else:
|
||||
import sys
|
||||
|
||||
_file = globals()["__file__"]
|
||||
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
||||
168
src/transformers/models/eomt/configuration_eomt.py
Normal file
168
src/transformers/models/eomt/configuration_eomt.py
Normal file
@@ -0,0 +1,168 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from src/transformers/models/eomt/modular_eomt.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_eomt.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# coding=utf-8
|
||||
# Copyright 2025 Mobile Perception Systems Lab at TU/e and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
|
||||
|
||||
class EomtConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`EomtForUniversalSegmentation`]. It is used to instantiate an EoMT model
|
||||
according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||
defaults will yield a similar configuration to that of the EoMT
|
||||
[tue-mps/coco_panoptic_eomt_large_640](https://huggingface.co/tue-mps/coco_panoptic_eomt_large_640)
|
||||
architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
hidden_size (`int`, *optional*, defaults to 1024):
|
||||
Dimensionality of the hidden representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 24):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 16):
|
||||
Number of attention heads in each attention layer.
|
||||
mlp_ratio (`int`, *optional*, defaults to 4):
|
||||
Ratio of the MLP hidden dimensionality to the hidden size.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
|
||||
The non-linear activation function (function or string) in the encoder.
|
||||
hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
|
||||
The dropout probability for all fully connected layers in the embeddings and encoder.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||
The epsilon used by the layer normalization layers.
|
||||
image_size (`int`, *optional*, defaults to 640):
|
||||
The size (resolution) of each input image.
|
||||
patch_size (`int`, *optional*, defaults to 16):
|
||||
The size (resolution) of each patch.
|
||||
num_channels (`int`, *optional*, defaults to 3):
|
||||
The number of input channels.
|
||||
layerscale_value (`float`, *optional*, defaults to 1.0):
|
||||
Initial value for the LayerScale parameter.
|
||||
drop_path_rate (`float`, *optional*, defaults to 0.0):
|
||||
The stochastic depth rate (drop path) used during training.
|
||||
num_upscale_blocks (`int`, *optional*, defaults to 2):
|
||||
Number of upsampling blocks used in the decoder or segmentation head.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
Dropout probability applied after attention projection.
|
||||
use_swiglu_ffn (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use the SwiGLU feedforward neural network.
|
||||
num_blocks (`int`, *optional*, defaults to 4):
|
||||
Number of feature blocks or stages in the architecture.
|
||||
no_object_weight (`float`, *optional*, defaults to 0.1):
|
||||
Loss weight for the 'no object' class in panoptic/instance segmentation.
|
||||
class_weight (`float`, *optional*, defaults to 2.0):
|
||||
Loss weight for classification targets.
|
||||
mask_weight (`float`, *optional*, defaults to 5.0):
|
||||
Loss weight for mask prediction.
|
||||
dice_weight (`float`, *optional*, defaults to 5.0):
|
||||
Loss weight for the dice loss component.
|
||||
train_num_points (`int`, *optional*, defaults to 12544):
|
||||
Number of points to sample for mask loss computation during training.
|
||||
oversample_ratio (`float`, *optional*, defaults to 3.0):
|
||||
Oversampling ratio used in point sampling for mask training.
|
||||
importance_sample_ratio (`float`, *optional*, defaults to 0.75):
|
||||
Ratio of points to sample based on importance during training.
|
||||
num_queries (`int`, *optional*, defaults to 200):
|
||||
Number of object queries in the Transformer.
|
||||
num_register_tokens (`int`, *optional*, defaults to 4):
|
||||
Number of learnable register tokens added to the transformer input.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import EomtConfig, EomtForUniversalSegmentation
|
||||
|
||||
>>> # Initialize configuration
|
||||
>>> config = EomtConfig()
|
||||
|
||||
>>> # Initialize model
|
||||
>>> model = EomtForUniversalSegmentation(config)
|
||||
|
||||
>>> # Access config
|
||||
>>> config = model.config
|
||||
```"""
|
||||
|
||||
model_type = "eomt"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size=1024,
|
||||
num_hidden_layers=24,
|
||||
num_attention_heads=16,
|
||||
mlp_ratio=4,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.0,
|
||||
initializer_range=0.02,
|
||||
layer_norm_eps=1e-6,
|
||||
image_size=640,
|
||||
patch_size=16,
|
||||
num_channels=3,
|
||||
layerscale_value=1.0,
|
||||
drop_path_rate=0.0,
|
||||
num_upscale_blocks=2,
|
||||
attention_dropout=0.0,
|
||||
use_swiglu_ffn=False,
|
||||
num_blocks=4,
|
||||
no_object_weight: float = 0.1,
|
||||
class_weight: float = 2.0,
|
||||
mask_weight: float = 5.0,
|
||||
dice_weight: float = 5.0,
|
||||
train_num_points: int = 12544,
|
||||
oversample_ratio: float = 3.0,
|
||||
importance_sample_ratio: float = 0.75,
|
||||
num_queries=200,
|
||||
num_register_tokens=4,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.initializer_range = initializer_range
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.image_size = image_size
|
||||
self.patch_size = patch_size
|
||||
self.num_channels = num_channels
|
||||
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.attention_dropout = attention_dropout
|
||||
self.layerscale_value = layerscale_value
|
||||
self.drop_path_rate = drop_path_rate
|
||||
self.num_upscale_blocks = num_upscale_blocks
|
||||
self.use_swiglu_ffn = use_swiglu_ffn
|
||||
self.num_blocks = num_blocks
|
||||
self.no_object_weight = no_object_weight
|
||||
self.class_weight = class_weight
|
||||
self.mask_weight = mask_weight
|
||||
self.dice_weight = dice_weight
|
||||
self.train_num_points = train_num_points
|
||||
self.oversample_ratio = oversample_ratio
|
||||
self.importance_sample_ratio = importance_sample_ratio
|
||||
self.num_queries = num_queries
|
||||
self.num_register_tokens = num_register_tokens
|
||||
|
||||
|
||||
__all__ = ["EomtConfig"]
|
||||
340
src/transformers/models/eomt/convert_eomt_to_hf.py
Normal file
340
src/transformers/models/eomt/convert_eomt_to_hf.py
Normal file
@@ -0,0 +1,340 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from transformers import EomtConfig, EomtForUniversalSegmentation, EomtImageProcessorFast
|
||||
|
||||
|
||||
# fmt: off
|
||||
MAPPINGS = {
|
||||
# Embeddings
|
||||
r"network.encoder.backbone.cls_token" : r"embeddings.cls_token",
|
||||
r"network.encoder.backbone.reg_token" : r"embeddings.register_tokens",
|
||||
r"network.encoder.backbone.pos_embed" : r"embeddings.position_embeddings.weight",
|
||||
r"network.encoder.backbone.patch_embed.proj" : r"embeddings.patch_embeddings.projection",
|
||||
|
||||
# Encoder Block
|
||||
r"network.encoder.backbone.blocks.(\d+).norm1" : r"layers.\1.norm1",
|
||||
r"network.encoder.backbone.blocks.(\d+).attn.proj" : r"layers.\1.attention.out_proj",
|
||||
r"network.encoder.backbone.blocks.(\d+).ls1.gamma" : r"layers.\1.layer_scale1.lambda1",
|
||||
r"network.encoder.backbone.blocks.(\d+).norm2" : r"layers.\1.norm2",
|
||||
r"network.encoder.backbone.blocks.(\d+).ls2.gamma" : r"layers.\1.layer_scale2.lambda1",
|
||||
r"network.encoder.backbone.blocks.(\d+).attn" : r"layers.\1.attention",
|
||||
|
||||
# Others
|
||||
r"network.q.weight" : r"query.weight",
|
||||
r"network.class_head" : r"class_predictor",
|
||||
r"network.upscale.(\d+).conv1" : r"upscale_block.block.\1.conv1",
|
||||
r"network.upscale.(\d+).conv2" : r"upscale_block.block.\1.conv2",
|
||||
r"network.upscale.(\d+).norm" : r"upscale_block.block.\1.layernorm2d",
|
||||
r"network.mask_head.0" : r"mask_head.fc1",
|
||||
r"network.mask_head.2" : r"mask_head.fc2",
|
||||
r"network.mask_head.4" : r"mask_head.fc3",
|
||||
r"network.encoder.backbone.norm" : r"layernorm",
|
||||
r"network.attn_mask_probs" : r"attn_mask_probs",
|
||||
}
|
||||
# fmt: on
|
||||
|
||||
# Mappings for MLP layers, depending on the type of MLP used in ckpts.
|
||||
MLP_MAPPINGS = {
|
||||
"swiglu_ffn": {
|
||||
r"network.encoder.backbone.blocks.(\d+).mlp.fc1": r"layers.\1.mlp.weights_in",
|
||||
r"network.encoder.backbone.blocks.(\d+).mlp.fc2": r"layers.\1.mlp.weights_out",
|
||||
},
|
||||
"vanilla_mlp": {
|
||||
r"network.encoder.backbone.blocks.(\d+).mlp": r"layers.\1.mlp",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def convert_old_keys_to_new_keys(state_dict):
|
||||
keys_as_text = "\n".join(state_dict.keys())
|
||||
new_keys_as_text = keys_as_text
|
||||
for old, repl in MAPPINGS.items():
|
||||
if repl is None:
|
||||
new_keys_as_text = re.sub(old, "", new_keys_as_text)
|
||||
else:
|
||||
new_keys_as_text = re.sub(old, repl, new_keys_as_text)
|
||||
output_dict = dict(zip(keys_as_text.split("\n"), new_keys_as_text.split("\n")))
|
||||
return output_dict
|
||||
|
||||
|
||||
def split_qkv_tensor(key, tensor):
|
||||
"""Splits a qkv tensor into separate q, k, v tensors and updates the key accordingly."""
|
||||
|
||||
new_keys = ["q_proj", "k_proj", "v_proj"]
|
||||
split_size = tensor.shape[0] // 3
|
||||
split_tensors = torch.split(tensor, split_size, dim=0)
|
||||
|
||||
return {key.replace("qkv", new_key): split_tensors[i] for i, new_key in enumerate(new_keys)}
|
||||
|
||||
|
||||
def convert_state_dict_to_hf(state_dict):
|
||||
"""Convert state dict keys to HF format."""
|
||||
conversion_dict = convert_old_keys_to_new_keys(state_dict)
|
||||
converted_state_dict = {}
|
||||
|
||||
for old_key, new_key in conversion_dict.items():
|
||||
if new_key:
|
||||
if "qkv" in new_key: # Detect merged attention keys and split them.
|
||||
qkv_split_dict = split_qkv_tensor(new_key, state_dict[old_key])
|
||||
converted_state_dict.update(qkv_split_dict)
|
||||
else:
|
||||
converted_state_dict[new_key] = state_dict[old_key]
|
||||
|
||||
for i in [
|
||||
"network.encoder.pixel_mean",
|
||||
"network.encoder.pixel_std",
|
||||
]:
|
||||
converted_state_dict.pop(i)
|
||||
|
||||
# Embeddings will not have initial dimension
|
||||
pos_embed_key = "embeddings.position_embeddings.weight"
|
||||
converted_state_dict[pos_embed_key] = converted_state_dict[pos_embed_key].squeeze(0)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def ensure_model_downloaded(
|
||||
repo_id: Optional[str] = None, revision: Optional[str] = None, local_dir: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
Ensures model files are downloaded locally, downloads them if not.
|
||||
Returns path to local files.
|
||||
|
||||
Args:
|
||||
repo_id: The Hugging Face model repo ID (required if local_dir not provided)
|
||||
revision: Optional git revision to use
|
||||
local_dir: Optional local directory path where model files should be stored/found
|
||||
"""
|
||||
if local_dir is not None:
|
||||
if os.path.exists(local_dir):
|
||||
print(f"Using provided local directory: {local_dir}")
|
||||
else:
|
||||
# Create the local directory if it doesn't exist
|
||||
os.makedirs(local_dir, exist_ok=True)
|
||||
print(f"Created local directory: {local_dir}")
|
||||
|
||||
if repo_id is None:
|
||||
raise ValueError("Either repo_id or local_dir must be provided")
|
||||
|
||||
print(f"Ensuring {repo_id} (revision: {revision or 'latest'}) is downloaded...")
|
||||
|
||||
try:
|
||||
# First try to find files locally
|
||||
download_dir = snapshot_download(repo_id, revision=revision, local_files_only=True, local_dir=local_dir)
|
||||
print(f"Found model files locally at {download_dir}")
|
||||
return download_dir
|
||||
except Exception:
|
||||
# If files not found locally, download them
|
||||
print(f"Downloading model files for {repo_id}...")
|
||||
download_dir = snapshot_download(repo_id, revision=revision, local_files_only=False, local_dir=local_dir)
|
||||
print(f"Downloaded model files to {download_dir}")
|
||||
return download_dir
|
||||
|
||||
|
||||
def load_model_state_dict(input_path: str) -> dict:
|
||||
"""
|
||||
Load model state dict, handling both single and sharded files.
|
||||
"""
|
||||
index_path = os.path.join(input_path, "pytorch_model.bin.index.json")
|
||||
single_file_path = os.path.join(input_path, "pytorch_model.bin")
|
||||
|
||||
# Check if we have a sharded model
|
||||
if os.path.exists(index_path):
|
||||
print("Loading sharded model...")
|
||||
state_dict = {}
|
||||
with open(index_path, "r") as f:
|
||||
index = json.load(f)
|
||||
|
||||
# Get unique shard files and load each one only once
|
||||
unique_shard_files = sorted(set(index["weight_map"].values()))
|
||||
for shard_file in unique_shard_files:
|
||||
print(f"Loading shard {shard_file}...")
|
||||
shard_path = os.path.join(input_path, shard_file)
|
||||
shard_dict = torch.load(shard_path, map_location="cpu")
|
||||
state_dict.update(shard_dict)
|
||||
|
||||
return state_dict
|
||||
|
||||
# Single file model
|
||||
elif os.path.exists(single_file_path):
|
||||
print("Loading single file model...")
|
||||
return torch.load(single_file_path, map_location="cpu")
|
||||
|
||||
else:
|
||||
raise ValueError(f"No model files found in {input_path}")
|
||||
|
||||
|
||||
def convert_model(
|
||||
repo_id=None,
|
||||
local_dir=None,
|
||||
output_dir=None,
|
||||
output_hub_path=None,
|
||||
safe_serialization=True,
|
||||
revision=None,
|
||||
):
|
||||
"""Convert and save the model weights, processor, and configuration."""
|
||||
if output_dir is None and output_hub_path is None:
|
||||
raise ValueError("At least one of output_dir or output_hub_path must be specified")
|
||||
|
||||
if repo_id is None and local_dir is None:
|
||||
raise ValueError("Either repo_id or local_dir must be specified")
|
||||
|
||||
# Create output directory if specified
|
||||
if output_dir:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
print(f"Created/verified output directory: {output_dir}")
|
||||
|
||||
torch.set_default_dtype(torch.float16)
|
||||
|
||||
# Download or locate model files
|
||||
input_path = ensure_model_downloaded(repo_id=repo_id, revision=revision, local_dir=local_dir)
|
||||
|
||||
with open(os.path.join(input_path, "config.json"), "r") as f:
|
||||
config_data = json.load(f)
|
||||
# Pop off unwanted keys
|
||||
_ = config_data.pop("backbone", None)
|
||||
|
||||
config = EomtConfig(
|
||||
**{
|
||||
**config_data,
|
||||
"layerscale_value": 1e-5,
|
||||
}
|
||||
)
|
||||
|
||||
if "semantic" in repo_id.split("_"):
|
||||
size = {"shortest_edge": config.image_size, "longest_edge": None}
|
||||
do_split_image = True
|
||||
do_pad = False
|
||||
else:
|
||||
size = {"shortest_edge": config.image_size, "longest_edge": config.image_size}
|
||||
do_split_image = False
|
||||
do_pad = True
|
||||
|
||||
if "giant" in repo_id.split("_"):
|
||||
config.use_swiglu_ffn = True
|
||||
config.hidden_size = 1536
|
||||
config.num_hidden_layers = 40
|
||||
config.num_attention_heads = 24
|
||||
# Update MAPPINGS for ckpts depending on the MLP type
|
||||
MAPPINGS.update(MLP_MAPPINGS["swiglu_ffn"])
|
||||
else:
|
||||
MAPPINGS.update(MLP_MAPPINGS["vanilla_mlp"])
|
||||
|
||||
processor = EomtImageProcessorFast(size=size, do_split_image=do_split_image, do_pad=do_pad)
|
||||
|
||||
# Save the config and processor
|
||||
if output_dir:
|
||||
config.save_pretrained(output_dir)
|
||||
processor.save_pretrained(output_dir)
|
||||
if output_hub_path:
|
||||
config.push_to_hub(output_hub_path)
|
||||
processor.push_to_hub(output_hub_path)
|
||||
|
||||
# Initialize model with empty weights
|
||||
print("Creating empty model...")
|
||||
with init_empty_weights():
|
||||
model = EomtForUniversalSegmentation(config)
|
||||
|
||||
# Load and convert state dict
|
||||
print("Loading state dict...")
|
||||
state_dict = load_model_state_dict(input_path)
|
||||
state_dict = convert_state_dict_to_hf(state_dict)
|
||||
|
||||
# Load converted state dict
|
||||
print("Loading converted weights into model...")
|
||||
model.load_state_dict(state_dict, strict=True, assign=True)
|
||||
|
||||
# Save the model
|
||||
if output_dir:
|
||||
print(f"Saving model to {output_dir}...")
|
||||
model.save_pretrained(output_dir, safe_serialization=safe_serialization)
|
||||
if output_hub_path:
|
||||
print(f"Pushing model to hub at {output_hub_path}...")
|
||||
model.push_to_hub(output_hub_path, safe_serialization=safe_serialization)
|
||||
|
||||
del state_dict, model
|
||||
gc.collect()
|
||||
|
||||
# Validate the saved model if saved locally
|
||||
if output_dir:
|
||||
print("Reloading the local model to check if it's saved correctly...")
|
||||
EomtForUniversalSegmentation.from_pretrained(output_dir, device_map="auto")
|
||||
print("Local model reloaded successfully.")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--hf_repo_id",
|
||||
help="HuggingFace Hub repo ID for the model",
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--local_dir",
|
||||
help="Local directory containing the model files",
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--revision",
|
||||
help="Specific revision to download from the Hub",
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
help="Location to write HF model locally",
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_hub_path",
|
||||
help="Repository ID to push model to hub (e.g. 'username/model-name')",
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--safe_serialization",
|
||||
action="store_true",
|
||||
help="Whether to save using safetensors",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.output_dir is None and args.output_hub_path is None:
|
||||
raise ValueError("At least one of --output_dir or --output_hub_path must be specified")
|
||||
|
||||
if args.hf_repo_id is None and args.local_dir is None:
|
||||
raise ValueError("Either --hf_repo_id or --local_dir must be specified")
|
||||
|
||||
convert_model(
|
||||
repo_id=args.hf_repo_id,
|
||||
local_dir=args.local_dir,
|
||||
output_dir=args.output_dir,
|
||||
output_hub_path=args.output_hub_path,
|
||||
safe_serialization=args.safe_serialization,
|
||||
revision=args.revision,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
972
src/transformers/models/eomt/image_processing_eomt.py
Normal file
972
src/transformers/models/eomt/image_processing_eomt.py
Normal file
@@ -0,0 +1,972 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 Mobile Perception Systems Lab at TU/e and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Image processor class for EoMT."""
|
||||
|
||||
import math
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
||||
from ...image_transforms import (
|
||||
PaddingMode,
|
||||
pad,
|
||||
resize,
|
||||
)
|
||||
from ...image_utils import (
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
get_image_size,
|
||||
infer_channel_dimension_format,
|
||||
make_flat_list_of_images,
|
||||
make_list_of_images,
|
||||
to_numpy_array,
|
||||
valid_images,
|
||||
validate_preprocess_arguments,
|
||||
)
|
||||
from ...utils import (
|
||||
IMAGENET_DEFAULT_MEAN,
|
||||
IMAGENET_DEFAULT_STD,
|
||||
TensorType,
|
||||
filter_out_non_signature_kwargs,
|
||||
is_torch_available,
|
||||
logging,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
# Adapted from transformers.models.maskformer.image_processing_maskformer.convert_segmentation_map_to_binary_masks
|
||||
def convert_segmentation_map_to_binary_masks(
|
||||
segmentation_map: "np.ndarray",
|
||||
instance_id_to_semantic_id: Optional[dict[int, int]] = None,
|
||||
ignore_index: Optional[int] = None,
|
||||
):
|
||||
if ignore_index is not None:
|
||||
segmentation_map = np.where(segmentation_map == 0, ignore_index, segmentation_map - 1)
|
||||
|
||||
# Get unique ids (class or instance ids based on input)
|
||||
all_labels = np.unique(segmentation_map)
|
||||
|
||||
# Drop background label if applicable
|
||||
if ignore_index is not None:
|
||||
all_labels = all_labels[all_labels != ignore_index]
|
||||
|
||||
# Generate a binary mask for each object instance
|
||||
binary_masks = [(segmentation_map == i) for i in all_labels]
|
||||
|
||||
# Stack the binary masks
|
||||
if binary_masks:
|
||||
binary_masks = np.stack(binary_masks, axis=0)
|
||||
else:
|
||||
binary_masks = np.zeros((0, *segmentation_map.shape))
|
||||
|
||||
# Convert instance ids to class ids
|
||||
if instance_id_to_semantic_id is not None:
|
||||
labels = np.zeros(all_labels.shape[0])
|
||||
|
||||
for label in all_labels:
|
||||
class_id = instance_id_to_semantic_id[label + 1 if ignore_index is not None else label]
|
||||
labels[all_labels == label] = class_id - 1 if ignore_index is not None else class_id
|
||||
else:
|
||||
labels = all_labels
|
||||
|
||||
return binary_masks.astype(np.float32), labels.astype(np.int64)
|
||||
|
||||
|
||||
def get_size_with_aspect_ratio(image_size, size, max_size=None) -> tuple[int, int]:
|
||||
"""
|
||||
Computes the output image size given the input image size and the desired output size.
|
||||
|
||||
Args:
|
||||
image_size (`Tuple[int, int]`):
|
||||
The input image size.
|
||||
size (`int`):
|
||||
The desired output size.
|
||||
max_size (`int`, *optional*):
|
||||
The maximum allowed output size.
|
||||
"""
|
||||
height, width = image_size
|
||||
raw_size = None
|
||||
if max_size is not None:
|
||||
min_original_size = float(min((height, width)))
|
||||
max_original_size = float(max((height, width)))
|
||||
if max_original_size / min_original_size * size > max_size:
|
||||
raw_size = max_size * min_original_size / max_original_size
|
||||
size = int(round(raw_size))
|
||||
|
||||
if (height <= width and height == size) or (width <= height and width == size):
|
||||
oh, ow = height, width
|
||||
elif width < height:
|
||||
ow = size
|
||||
if max_size is not None and raw_size is not None:
|
||||
oh = round(raw_size * height / width)
|
||||
else:
|
||||
oh = round(size * height / width)
|
||||
else:
|
||||
oh = size
|
||||
if max_size is not None and raw_size is not None:
|
||||
ow = round(raw_size * width / height)
|
||||
else:
|
||||
ow = round(size * width / height)
|
||||
|
||||
return (oh, ow)
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.image_processing_detr.remove_low_and_no_objects
|
||||
def remove_low_and_no_objects(masks, scores, labels, object_mask_threshold, num_labels):
|
||||
"""
|
||||
Binarize the given masks using `object_mask_threshold`, it returns the associated values of `masks`, `scores` and
|
||||
`labels`.
|
||||
|
||||
Args:
|
||||
masks (`torch.Tensor`):
|
||||
A tensor of shape `(num_queries, height, width)`.
|
||||
scores (`torch.Tensor`):
|
||||
A tensor of shape `(num_queries)`.
|
||||
labels (`torch.Tensor`):
|
||||
A tensor of shape `(num_queries)`.
|
||||
object_mask_threshold (`float`):
|
||||
A number between 0 and 1 used to binarize the masks.
|
||||
Raises:
|
||||
`ValueError`: Raised when the first dimension doesn't match in all input tensors.
|
||||
Returns:
|
||||
`tuple[`torch.Tensor`, `torch.Tensor`, `torch.Tensor`]`: The `masks`, `scores` and `labels` without the region
|
||||
< `object_mask_threshold`.
|
||||
"""
|
||||
if not (masks.shape[0] == scores.shape[0] == labels.shape[0]):
|
||||
raise ValueError("mask, scores and labels must have the same shape!")
|
||||
|
||||
to_keep = labels.ne(num_labels) & (scores > object_mask_threshold)
|
||||
|
||||
return masks[to_keep], scores[to_keep], labels[to_keep]
|
||||
|
||||
|
||||
def check_segment_validity(mask_labels, mask_probs, k, mask_threshold=0.5, overlap_mask_area_threshold=0.8):
|
||||
# Get the mask associated with the k class
|
||||
mask_k = mask_labels == k
|
||||
mask_k_area = mask_k.sum()
|
||||
|
||||
# Compute the area of all the stuff in query k
|
||||
original_mask = mask_probs[k] >= mask_threshold
|
||||
original_area = original_mask.sum()
|
||||
|
||||
final_mask = mask_k & original_mask
|
||||
final_mask_area = final_mask.sum()
|
||||
|
||||
mask_exists = mask_k_area > 0 and original_area > 0 and final_mask_area > 0
|
||||
|
||||
if mask_exists:
|
||||
area_ratio = mask_k_area / original_area
|
||||
if not area_ratio.item() > overlap_mask_area_threshold:
|
||||
mask_exists = False
|
||||
|
||||
return mask_exists, final_mask
|
||||
|
||||
|
||||
def compute_segments(
|
||||
mask_probs,
|
||||
pred_scores,
|
||||
pred_labels,
|
||||
stuff_classes,
|
||||
mask_threshold: float = 0.5,
|
||||
overlap_mask_area_threshold: float = 0.8,
|
||||
target_size: Optional[tuple[int, int]] = None,
|
||||
):
|
||||
height = mask_probs.shape[1] if target_size is None else target_size[0]
|
||||
width = mask_probs.shape[2] if target_size is None else target_size[1]
|
||||
|
||||
segmentation = torch.zeros((height, width), dtype=torch.long, device=mask_probs.device) - 1
|
||||
segments: list[dict] = []
|
||||
|
||||
# Compute per-pixel assignment based on weighted mask scores
|
||||
mask_probs = mask_probs.sigmoid()
|
||||
mask_labels = (pred_scores[:, None, None] * mask_probs).argmax(0)
|
||||
|
||||
# Keep track of instances of each class
|
||||
current_segment_id = 0
|
||||
stuff_memory_list: dict[str, int] = {}
|
||||
|
||||
for k in range(pred_labels.shape[0]):
|
||||
pred_class = pred_labels[k].item()
|
||||
|
||||
# Check if mask exists and large enough to be a segment
|
||||
mask_exists, final_mask = check_segment_validity(
|
||||
mask_labels, mask_probs, k, mask_threshold, overlap_mask_area_threshold
|
||||
)
|
||||
|
||||
if not mask_exists:
|
||||
continue
|
||||
|
||||
if stuff_classes and pred_class in stuff_classes:
|
||||
if pred_class in stuff_memory_list:
|
||||
segmentation[final_mask] = stuff_memory_list[pred_class]
|
||||
continue
|
||||
else:
|
||||
stuff_memory_list[pred_class] = current_segment_id
|
||||
|
||||
segmentation[final_mask] = current_segment_id
|
||||
segment_score = round(pred_scores[k].item(), 6)
|
||||
segments.append(
|
||||
{
|
||||
"id": current_segment_id,
|
||||
"label_id": pred_class,
|
||||
"score": segment_score,
|
||||
}
|
||||
)
|
||||
current_segment_id += 1
|
||||
return segmentation, segments
|
||||
|
||||
|
||||
def get_target_size(size_dict: dict[str, int]) -> tuple[int, int]:
|
||||
"""Returns the height and width from a size dict."""
|
||||
target_height = size_dict["shortest_edge"]
|
||||
target_width = size_dict.get("longest_edge", None) or target_height
|
||||
|
||||
return target_height, target_width
|
||||
|
||||
|
||||
class EomtImageProcessor(BaseImageProcessor):
|
||||
r"""
|
||||
Constructs a EoMT image processor. The image processor can be used to prepare image(s) and optional targets
|
||||
for the model.
|
||||
|
||||
This image processor inherits from [`BaseImageProcessor`] which contains most of the main methods. Users should
|
||||
refer to this superclass for more information regarding those methods.
|
||||
|
||||
Args:
|
||||
do_resize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to resize the input to a certain `size`.
|
||||
size (`int`, *optional*, defaults to 640):
|
||||
Resize the input to the given size. Only has an effect if `do_resize` is set to `True`. If size is a
|
||||
sequence like `(width, height)`, output size will be matched to this. If size is an int, smaller edge of
|
||||
the image will be matched to this number. i.e, if `height > width`, then image will be rescaled to `(size *
|
||||
height / width, size)`.
|
||||
resample (`int`, *optional*, defaults to `Resampling.BILINEAR`):
|
||||
An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
|
||||
`PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
|
||||
`PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
|
||||
to `True`.
|
||||
do_rescale (`bool`, *optional*, defaults to `True`):
|
||||
Whether to rescale the input to a certain `scale`.
|
||||
rescale_factor (`float`, *optional*, defaults to `1/ 255`):
|
||||
Rescale the input by the given factor. Only has an effect if `do_rescale` is set to `True`.
|
||||
do_normalize (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to normalize the input with mean and standard deviation.
|
||||
do_split_image (`bool`, *optional*, defaults to `False`):
|
||||
Whether to split the input images into overlapping patches for semantic segmentation. If set to `True`, the
|
||||
input images will be split into patches of size `size["shortest_edge"]` with an overlap between patches.
|
||||
Otherwise, the input images will be padded to the target size.
|
||||
do_pad (`bool`, *optional*, defaults to `False`):
|
||||
Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
|
||||
number of patches in the batch. Padding will be applied to the bottom and right with zeros.
|
||||
image_mean (`int`, *optional*, defaults to `[0.485, 0.456, 0.406]`):
|
||||
The sequence of means for each channel, to be used when normalizing images. Defaults to the ImageNet mean.
|
||||
image_std (`int`, *optional*, defaults to `[0.229, 0.224, 0.225]`):
|
||||
The sequence of standard deviations for each channel, to be used when normalizing images. Defaults to the
|
||||
ImageNet std.
|
||||
ignore_index (`int`, *optional*):
|
||||
Label to be assigned to background pixels in segmentation maps. If provided, segmentation map pixels
|
||||
denoted with 0 (background) will be replaced with `ignore_index`.
|
||||
num_labels (`int`, *optional*):
|
||||
The number of labels in the segmentation map.
|
||||
"""
|
||||
|
||||
model_input_names = ["pixel_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
do_resize: bool = True,
|
||||
size: Optional[dict[str, int]] = None,
|
||||
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
||||
do_rescale: bool = True,
|
||||
rescale_factor: float = 1 / 255,
|
||||
do_normalize: bool = True,
|
||||
do_split_image: bool = False,
|
||||
do_pad: bool = False,
|
||||
image_mean: Optional[Union[float, list[float]]] = None,
|
||||
image_std: Optional[Union[float, list[float]]] = None,
|
||||
ignore_index: Optional[int] = None,
|
||||
num_labels: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
size = size if size is not None else {"shortest_edge": 640, "longest_edge": 640}
|
||||
size = get_size_dict(size, default_to_square=False)
|
||||
|
||||
self.do_resize = do_resize
|
||||
self.size = size
|
||||
self.resample = resample
|
||||
self.do_rescale = do_rescale
|
||||
self.rescale_factor = rescale_factor
|
||||
self.do_normalize = do_normalize
|
||||
self.do_split_image = do_split_image
|
||||
self.do_pad = do_pad
|
||||
self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
|
||||
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
|
||||
self.ignore_index = ignore_index
|
||||
self.num_labels = num_labels
|
||||
|
||||
def resize(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
size: dict,
|
||||
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
||||
data_format=None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge
|
||||
resized to keep the input aspect ratio.
|
||||
|
||||
Args:
|
||||
image (`np.ndarray`):
|
||||
Image to resize.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
|
||||
Resampling filter to use when resiizing the image.
|
||||
data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||
"""
|
||||
image_size = get_image_size(image)
|
||||
output_size = get_size_with_aspect_ratio(image_size, size["shortest_edge"], size["longest_edge"])
|
||||
|
||||
image = resize(
|
||||
image=image,
|
||||
size=output_size,
|
||||
resample=resample,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
return_numpy=True,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return image
|
||||
|
||||
def _split_image(self, image: ImageInput, size: dict, image_index: int) -> tuple[list, list]:
|
||||
"""Slices an image into overlapping patches for semantic segmentation."""
|
||||
|
||||
patches, patch_offsets = [], []
|
||||
|
||||
image_size = get_image_size(image)
|
||||
patch_size = size["shortest_edge"]
|
||||
|
||||
longer_side = max(image_size)
|
||||
num_patches = math.ceil(longer_side / patch_size)
|
||||
total_overlap = num_patches * patch_size - longer_side
|
||||
overlap_per_patch = total_overlap / (num_patches - 1) if num_patches > 1 else 0
|
||||
|
||||
for i in range(num_patches):
|
||||
start = int(i * (patch_size - overlap_per_patch))
|
||||
end = start + patch_size
|
||||
|
||||
if image_size[0] > image_size[1]:
|
||||
patch = image[:, start:end, :]
|
||||
else:
|
||||
patch = image[:, :, start:end]
|
||||
|
||||
patches.append(patch)
|
||||
patch_offsets.append([image_index, start, end])
|
||||
|
||||
return patches, patch_offsets
|
||||
|
||||
def _pad(self, image: ImageInput, size: dict) -> np.ndarray:
|
||||
"""Pads the image to the target size using zero padding."""
|
||||
height, width = get_image_size(image)
|
||||
|
||||
target_height, target_width = get_target_size(size)
|
||||
pad_h = max(0, target_height - height)
|
||||
pad_w = max(0, target_width - width)
|
||||
|
||||
padding = ((0, pad_h), (0, pad_w))
|
||||
|
||||
# Channel axis is last; default padding format is compatible
|
||||
padded_image = pad(image=image, padding=padding, mode=PaddingMode.CONSTANT, constant_values=0.0)
|
||||
return padded_image
|
||||
|
||||
def _preprocess_images(
|
||||
self,
|
||||
images: ImageInput,
|
||||
do_resize: Optional[bool] = None,
|
||||
size: Optional[dict[str, int]] = None,
|
||||
resample: PILImageResampling = None,
|
||||
do_split_image: Optional[bool] = None,
|
||||
do_pad: Optional[bool] = None,
|
||||
do_rescale: Optional[bool] = None,
|
||||
rescale_factor: Optional[float] = None,
|
||||
do_normalize: Optional[bool] = None,
|
||||
image_mean: Optional[Union[float, list[float]]] = None,
|
||||
image_std: Optional[Union[float, list[float]]] = None,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> np.ndarray:
|
||||
"""Preprocesses a batch of images."""
|
||||
images = [to_numpy_array(image) for image in images]
|
||||
|
||||
if do_resize:
|
||||
images = [
|
||||
self.resize(
|
||||
image,
|
||||
size=size,
|
||||
resample=resample,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
for image in images
|
||||
]
|
||||
|
||||
processed_images, patch_offsets = [], []
|
||||
|
||||
if do_split_image:
|
||||
for idx, img in enumerate(images):
|
||||
patches, offsets = self._split_image(img, size, idx)
|
||||
processed_images.extend(patches)
|
||||
patch_offsets.extend(offsets)
|
||||
|
||||
images = processed_images
|
||||
|
||||
if do_pad:
|
||||
images = [self._pad(img, size) for img in images]
|
||||
|
||||
if do_rescale:
|
||||
images = [self.rescale(img, scale=rescale_factor, input_data_format=input_data_format) for img in images]
|
||||
|
||||
if do_normalize:
|
||||
images = [
|
||||
self.normalize(
|
||||
image,
|
||||
mean=image_mean,
|
||||
std=image_std,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
for image in images
|
||||
]
|
||||
|
||||
return images, patch_offsets
|
||||
|
||||
def _preprocess_mask(
|
||||
self,
|
||||
segmentation_map: ImageInput,
|
||||
do_resize: Optional[bool] = False,
|
||||
do_pad: Optional[bool] = False,
|
||||
size: Optional[dict[str, int]] = None,
|
||||
resample: PILImageResampling = None,
|
||||
data_format: Union[str, ChannelDimension] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> np.ndarray:
|
||||
"""Preprocesses a single mask."""
|
||||
# Add channel dimension if missing - needed for certain transformations
|
||||
if segmentation_map.ndim == 2:
|
||||
added_channel_dim = True
|
||||
segmentation_map = segmentation_map[None, ...]
|
||||
input_data_format = ChannelDimension.FIRST
|
||||
else:
|
||||
added_channel_dim = False
|
||||
if input_data_format is None:
|
||||
input_data_format = infer_channel_dimension_format(segmentation_map)
|
||||
|
||||
if do_resize:
|
||||
segmentation_map = self.resize(
|
||||
segmentation_map,
|
||||
size=size,
|
||||
resample=resample,
|
||||
data_format=data_format,
|
||||
)
|
||||
|
||||
if do_pad:
|
||||
segmentation_map = self._pad(segmentation_map, size)
|
||||
|
||||
# Remove extra channel dimension if added for processing
|
||||
if added_channel_dim:
|
||||
segmentation_map = segmentation_map.squeeze(0)
|
||||
return torch.from_numpy(segmentation_map)
|
||||
|
||||
@filter_out_non_signature_kwargs()
|
||||
def preprocess(
|
||||
self,
|
||||
images: ImageInput,
|
||||
segmentation_maps: Optional[Union[list[dict[int, int]], dict[int, int]]] = None,
|
||||
instance_id_to_semantic_id: Optional[dict[int, int]] = None,
|
||||
do_split_image: Optional[bool] = None,
|
||||
do_resize: Optional[bool] = None,
|
||||
size: Optional[dict[str, int]] = None,
|
||||
resample: PILImageResampling = None,
|
||||
do_rescale: Optional[bool] = None,
|
||||
rescale_factor: Optional[float] = None,
|
||||
do_normalize: Optional[bool] = None,
|
||||
do_pad: Optional[bool] = None,
|
||||
image_mean: Optional[Union[float, list[float]]] = None,
|
||||
image_std: Optional[Union[float, list[float]]] = None,
|
||||
ignore_index: Optional[int] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Preprocesses images or a batch of images.
|
||||
|
||||
Args:
|
||||
images (`ImageInput`):
|
||||
Image or batch of images to preprocess.
|
||||
segmentation_maps (`ImageInput`, *optional*):
|
||||
The corresponding semantic segmentation maps with the pixel-wise annotations.
|
||||
instance_id_to_semantic_id (`List[Dict[int, int]]` or `Dict[int, int]`, *optional*):
|
||||
A mapping between object instance ids and class ids.
|
||||
do_split_image (`bool`, *optional*, defaults to `self.do_split_image`):
|
||||
Whether to split the input images into overlapping patches for semantic segmentation.
|
||||
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
||||
Whether to resize the input images.
|
||||
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
||||
Target size as a dictionary with `"shortest_edge"` and `"longest_edge"` keys.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
|
||||
Resampling filter to use when resizing.
|
||||
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
||||
Whether to rescale the input images by `rescale_factor`.
|
||||
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
||||
Factor to scale image pixel values.
|
||||
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
||||
Whether to normalize the input images.
|
||||
do_pad (`bool`, *optional*, defaults to `False`):
|
||||
Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
|
||||
number of patches in the batch. Padding will be applied to the bottom and right with zeros.
|
||||
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
||||
Mean for normalization. Single value or list for each channel.
|
||||
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
||||
Standard deviation for normalization. Single value or list for each channel.
|
||||
ignore_index (`int`, *optional*):
|
||||
Label to be assigned to background pixels in segmentation maps. If provided, segmentation map pixels
|
||||
denoted with 0 (background) will be replaced with `ignore_index`.
|
||||
return_tensors (`str` or `TensorType`, *optional*):
|
||||
The type of tensors to return. Can be `"pt"`, `"tf"`, `"np"`, or `"jax"`.
|
||||
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
||||
Channel format of the output image. Either `"channels_first"` or `"channels_last"`.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
Channel format of the input image.
|
||||
"""
|
||||
|
||||
do_split_image = do_split_image if do_split_image is not None else self.do_split_image
|
||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||
size = size if size is not None else self.size
|
||||
size = get_size_dict(size, default_to_square=False)
|
||||
resample = resample if resample is not None else self.resample
|
||||
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
||||
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
||||
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
||||
do_pad = do_pad if do_pad is not None else self.do_pad
|
||||
image_mean = image_mean if image_mean is not None else self.image_mean
|
||||
image_std = image_std if image_std is not None else self.image_std
|
||||
ignore_index = ignore_index if ignore_index is not None else self.ignore_index
|
||||
|
||||
images = make_flat_list_of_images(images)
|
||||
|
||||
if not valid_images(images):
|
||||
raise ValueError(
|
||||
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
||||
"torch.Tensor, tf.Tensor or jax.ndarray."
|
||||
)
|
||||
|
||||
validate_preprocess_arguments(
|
||||
do_rescale=do_rescale,
|
||||
rescale_factor=rescale_factor,
|
||||
do_normalize=do_normalize,
|
||||
image_mean=image_mean,
|
||||
image_std=image_std,
|
||||
do_resize=do_resize,
|
||||
size=size,
|
||||
resample=resample,
|
||||
)
|
||||
|
||||
pixel_values_list, patch_offsets = self._preprocess_images(
|
||||
images=images,
|
||||
do_resize=do_resize,
|
||||
size=size,
|
||||
resample=resample,
|
||||
do_split_image=do_split_image,
|
||||
do_pad=do_pad,
|
||||
do_rescale=do_rescale,
|
||||
rescale_factor=rescale_factor,
|
||||
do_normalize=do_normalize,
|
||||
image_mean=image_mean,
|
||||
image_std=image_std,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
|
||||
if segmentation_maps is not None:
|
||||
segmentation_maps = make_list_of_images(segmentation_maps, expected_ndims=2)
|
||||
segmentation_maps = [to_numpy_array(mask) for mask in segmentation_maps]
|
||||
|
||||
segmentation_maps = [
|
||||
self._preprocess_mask(
|
||||
segmentation_map,
|
||||
do_resize=do_resize,
|
||||
do_pad=do_pad,
|
||||
size=size,
|
||||
resample=PILImageResampling.NEAREST,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
for segmentation_map in segmentation_maps
|
||||
]
|
||||
|
||||
encoded_inputs = self.encode_inputs(
|
||||
pixel_values_list,
|
||||
segmentation_maps,
|
||||
instance_id_to_semantic_id,
|
||||
ignore_index,
|
||||
return_tensors,
|
||||
input_data_format=data_format,
|
||||
)
|
||||
|
||||
if do_split_image and patch_offsets:
|
||||
encoded_inputs["patch_offsets"] = patch_offsets
|
||||
|
||||
return encoded_inputs
|
||||
|
||||
def encode_inputs(
|
||||
self,
|
||||
pixel_values_list: list[ImageInput],
|
||||
segmentation_maps: ImageInput = None,
|
||||
instance_id_to_semantic_id: Optional[Union[list[dict[int, int]], dict[int, int]]] = None,
|
||||
ignore_index: Optional[int] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
):
|
||||
"""
|
||||
Pad images up to the largest image in a batch and create a corresponding `pixel_mask`.
|
||||
|
||||
EoMT addresses semantic segmentation with a mask classification paradigm, thus input segmentation maps
|
||||
will be converted to lists of binary masks and their respective labels. Let's see an example, assuming
|
||||
`segmentation_maps = [[2,6,7,9]]`, the output will contain `mask_labels =
|
||||
[[1,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,1]]` (four binary masks) and `class_labels = [2,6,7,9]`, the labels for
|
||||
each mask.
|
||||
|
||||
Args:
|
||||
pixel_values_list (`List[ImageInput]`):
|
||||
List of images (pixel values) to be padded. Each image should be a tensor of shape `(channels, height,
|
||||
width)`.
|
||||
|
||||
segmentation_maps (`ImageInput`, *optional*):
|
||||
The corresponding semantic segmentation maps with the pixel-wise annotations.
|
||||
|
||||
(`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to pad images up to the largest image in a batch and create a pixel mask.
|
||||
|
||||
If left to the default, will return a pixel mask that is:
|
||||
|
||||
- 1 for pixels that are real (i.e. **not masked**),
|
||||
- 0 for pixels that are padding (i.e. **masked**).
|
||||
|
||||
instance_id_to_semantic_id (`List[Dict[int, int]]` or `Dict[int, int]`, *optional*):
|
||||
A mapping between object instance ids and class ids. If passed, `segmentation_maps` is treated as an
|
||||
instance segmentation map where each pixel represents an instance id. Can be provided as a single
|
||||
dictionary with a global/dataset-level mapping or as a list of dictionaries (one per image), to map
|
||||
instance ids in each image separately.
|
||||
|
||||
return_tensors (`str` or [`~file_utils.TensorType`], *optional*):
|
||||
If set, will return tensors instead of NumPy arrays. If set to `'pt'`, return PyTorch `torch.Tensor`
|
||||
objects.
|
||||
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||
|
||||
Returns:
|
||||
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
|
||||
|
||||
- **pixel_values** -- Pixel values to be fed to a model.
|
||||
- **mask_labels** -- Optional list of mask labels of shape `(labels, height, width)` to be fed to a model
|
||||
(when `annotations` are provided).
|
||||
- **class_labels** -- Optional list of class labels of shape `(labels)` to be fed to a model (when
|
||||
`annotations` are provided). They identify the labels of `mask_labels`, e.g. the label of
|
||||
`mask_labels[i][j]` if `class_labels[i][j]`.
|
||||
"""
|
||||
ignore_index = self.ignore_index if ignore_index is None else ignore_index
|
||||
|
||||
pixel_values_list = [to_numpy_array(pixel_values) for pixel_values in pixel_values_list]
|
||||
|
||||
if input_data_format is None:
|
||||
input_data_format = infer_channel_dimension_format(pixel_values_list[0])
|
||||
|
||||
encoded_inputs = BatchFeature({"pixel_values": pixel_values_list}, tensor_type=return_tensors)
|
||||
|
||||
if segmentation_maps is not None:
|
||||
mask_labels = []
|
||||
class_labels = []
|
||||
# Convert to list of binary masks and labels
|
||||
for idx, segmentation_map in enumerate(segmentation_maps):
|
||||
segmentation_map = to_numpy_array(segmentation_map)
|
||||
if isinstance(instance_id_to_semantic_id, list):
|
||||
instance_id = instance_id_to_semantic_id[idx]
|
||||
else:
|
||||
instance_id = instance_id_to_semantic_id
|
||||
# Use instance2class_id mapping per image
|
||||
masks, classes = convert_segmentation_map_to_binary_masks(
|
||||
segmentation_map,
|
||||
instance_id,
|
||||
ignore_index=ignore_index,
|
||||
)
|
||||
|
||||
mask_labels.append(torch.from_numpy(masks))
|
||||
class_labels.append(torch.from_numpy(classes))
|
||||
|
||||
# we cannot batch them since they don't share a common class size
|
||||
encoded_inputs["mask_labels"] = mask_labels
|
||||
encoded_inputs["class_labels"] = class_labels
|
||||
|
||||
return encoded_inputs
|
||||
|
||||
def merge_image_patches(
|
||||
self,
|
||||
segmentation_logits: torch.Tensor,
|
||||
patch_offsets: list[tuple[int, int, int]],
|
||||
original_image_sizes: list[tuple[int, int]],
|
||||
size: dict[str, int],
|
||||
) -> list[torch.Tensor]:
|
||||
"""
|
||||
Reconstructs full-size semantic segmentation logits from patch predictions.
|
||||
|
||||
Args:
|
||||
segmentation_logits (`torch.Tensor`):
|
||||
A tensor of shape `(num_patches, num_classes, patch_height, patch_width)` representing predicted logits
|
||||
for each image patch.
|
||||
patch_offsets (`List[Tuple[int, int, int]]`):
|
||||
A list of tuples where each tuple contains:
|
||||
- `image_index` (int): Index of the original image this patch belongs to.
|
||||
- `start` (int): Start pixel index of the patch along the long dimension (height or width).
|
||||
- `end` (int): End pixel index of the patch along the long dimension.
|
||||
original_image_sizes (`List[Tuple[int, int]]`):
|
||||
List of original (height, width) dimensions for each image before preprocessing.
|
||||
size (`Dict[str, int]`):
|
||||
A size dict which was used to resize.
|
||||
"""
|
||||
num_classes = segmentation_logits.shape[1]
|
||||
aggregated_logits = []
|
||||
patch_counts = []
|
||||
|
||||
for image_size in original_image_sizes:
|
||||
height, width = get_size_with_aspect_ratio(image_size, size["shortest_edge"], size["longest_edge"])
|
||||
aggregated_logits.append(torch.zeros((num_classes, height, width), device=segmentation_logits.device))
|
||||
patch_counts.append(torch.zeros((num_classes, height, width), device=segmentation_logits.device))
|
||||
|
||||
# Stitch patches back into full-sized logit maps
|
||||
for patch_idx, (image_idx, patch_start, patch_end) in enumerate(patch_offsets):
|
||||
if original_image_sizes[image_idx][0] > original_image_sizes[image_idx][1]:
|
||||
aggregated_logits[image_idx][:, patch_start:patch_end, :] += segmentation_logits[patch_idx]
|
||||
patch_counts[image_idx][:, patch_start:patch_end, :] += 1
|
||||
else:
|
||||
aggregated_logits[image_idx][:, :, patch_start:patch_end] += segmentation_logits[patch_idx]
|
||||
patch_counts[image_idx][:, :, patch_start:patch_end] += 1
|
||||
|
||||
# Normalize and resize logits to original image size
|
||||
reconstructed_logits = []
|
||||
for idx, (logit_sum, count) in enumerate(zip(aggregated_logits, patch_counts)):
|
||||
averaged_logits = logit_sum / count.clamp(min=1)
|
||||
resized_logits = F.interpolate(
|
||||
averaged_logits[None, ...],
|
||||
size=original_image_sizes[idx],
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
)[0]
|
||||
|
||||
reconstructed_logits.append(resized_logits)
|
||||
|
||||
return reconstructed_logits
|
||||
|
||||
def unpad_image(
|
||||
self,
|
||||
segmentation_logits: torch.Tensor,
|
||||
original_image_sizes: list[tuple[int, int]],
|
||||
size: dict[str, int],
|
||||
) -> list[torch.Tensor]:
|
||||
"""Restores panoptic segmentation logits to their original image resolutions."""
|
||||
|
||||
resized_logits = []
|
||||
|
||||
for idx, original_size in enumerate(original_image_sizes):
|
||||
target_height, target_width = get_size_with_aspect_ratio(
|
||||
original_size, size["shortest_edge"], size["longest_edge"]
|
||||
)
|
||||
cropped_logits = segmentation_logits[idx][:, :target_height, :target_width]
|
||||
upsampled_logits = F.interpolate(
|
||||
cropped_logits[None, ...], size=original_size, mode="bilinear", align_corners=False
|
||||
)[0]
|
||||
resized_logits.append(upsampled_logits)
|
||||
return resized_logits
|
||||
|
||||
def post_process_semantic_segmentation(
|
||||
self,
|
||||
outputs,
|
||||
patch_offsets: list[tuple[int, int, int]],
|
||||
original_image_sizes: list[tuple[int, int]],
|
||||
size: Optional[dict[str, int]] = None,
|
||||
) -> np.ndarray:
|
||||
"""Post-processes model outputs into final semantic segmentation prediction."""
|
||||
|
||||
size = size if size is not None else self.size
|
||||
|
||||
masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width]
|
||||
class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1]
|
||||
|
||||
output_size = get_target_size(size)
|
||||
masks_queries_logits = F.interpolate(
|
||||
masks_queries_logits,
|
||||
size=output_size,
|
||||
mode="bilinear",
|
||||
)
|
||||
|
||||
# Remove the null class `[..., :-1]`
|
||||
masks_classes = class_queries_logits.softmax(dim=-1)[..., :-1]
|
||||
masks_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width]
|
||||
|
||||
segmentation_logits = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs)
|
||||
|
||||
output_logits = self.merge_image_patches(segmentation_logits, patch_offsets, original_image_sizes, size)
|
||||
|
||||
preds = torch.stack(output_logits).argmax(dim=1)
|
||||
return preds
|
||||
|
||||
def post_process_panoptic_segmentation(
|
||||
self,
|
||||
outputs,
|
||||
original_image_sizes: list[tuple[int, int]],
|
||||
threshold: float = 0.8,
|
||||
mask_threshold: float = 0.5,
|
||||
overlap_mask_area_threshold: float = 0.8,
|
||||
stuff_classes: Optional[list[int]] = None,
|
||||
size: Optional[dict[str, int]] = None,
|
||||
):
|
||||
"""Post-processes model outputs into final panoptic segmentation prediction."""
|
||||
|
||||
size = size if size is not None else self.size
|
||||
|
||||
masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width]
|
||||
class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1]
|
||||
|
||||
batch_size = class_queries_logits.shape[0]
|
||||
num_labels = class_queries_logits.shape[-1] - 1
|
||||
|
||||
output_size = get_target_size(size)
|
||||
masks_queries_logits = F.interpolate(
|
||||
masks_queries_logits,
|
||||
size=output_size,
|
||||
mode="bilinear",
|
||||
)
|
||||
|
||||
mask_probs_batch = self.unpad_image(masks_queries_logits, original_image_sizes, size)
|
||||
pred_scores_batch, pred_labels_batch = class_queries_logits.softmax(dim=-1).max(-1)
|
||||
|
||||
results: list = []
|
||||
|
||||
for i in range(batch_size):
|
||||
mask_probs, pred_scores, pred_labels = remove_low_and_no_objects(
|
||||
mask_probs_batch[i], pred_scores_batch[i], pred_labels_batch[i], threshold, num_labels
|
||||
)
|
||||
|
||||
# No mask found
|
||||
if mask_probs.shape[0] <= 0:
|
||||
height, width = original_image_sizes[i] if original_image_sizes is not None else mask_probs.shape[1:]
|
||||
segmentation = torch.zeros((height, width)) - 1
|
||||
results.append({"segmentation": segmentation, "segments_info": []})
|
||||
continue
|
||||
|
||||
segmentation, segments = compute_segments(
|
||||
mask_probs=mask_probs,
|
||||
pred_scores=pred_scores,
|
||||
pred_labels=pred_labels,
|
||||
stuff_classes=stuff_classes,
|
||||
mask_threshold=mask_threshold,
|
||||
overlap_mask_area_threshold=overlap_mask_area_threshold,
|
||||
target_size=original_image_sizes[i] if original_image_sizes is not None else None,
|
||||
)
|
||||
|
||||
results.append({"segmentation": segmentation, "segments_info": segments})
|
||||
return results
|
||||
|
||||
def post_process_instance_segmentation(
|
||||
self,
|
||||
outputs,
|
||||
original_image_sizes: list[tuple[int, int]],
|
||||
threshold: float = 0.5,
|
||||
size: Optional[dict[str, int]] = None,
|
||||
):
|
||||
"""Post-processes model outputs into Instance Segmentation Predictions."""
|
||||
|
||||
size = size if size is not None else self.size
|
||||
|
||||
class_queries_logits = outputs.class_queries_logits
|
||||
masks_queries_logits = outputs.masks_queries_logits
|
||||
|
||||
output_size = get_target_size(size)
|
||||
masks_queries_logits = F.interpolate(
|
||||
masks_queries_logits,
|
||||
size=output_size,
|
||||
mode="bilinear",
|
||||
)
|
||||
|
||||
mask_probs_batch = self.unpad_image(masks_queries_logits, original_image_sizes, size)
|
||||
|
||||
device = masks_queries_logits.device
|
||||
batch_size = class_queries_logits.shape[0]
|
||||
num_queries = class_queries_logits.shape[-2]
|
||||
|
||||
results = []
|
||||
|
||||
for i in range(batch_size):
|
||||
mask_pred = mask_probs_batch[i]
|
||||
mask_class = class_queries_logits[i]
|
||||
|
||||
# Remove the null class `[..., :-1]`
|
||||
scores, pred_classes = mask_class.softmax(dim=-1)[..., :-1].max(-1)
|
||||
pred_masks = (mask_pred > 0).float()
|
||||
|
||||
# Calculate average mask prob
|
||||
mask_scores = (mask_pred.sigmoid().flatten(1) * pred_masks.flatten(1)).sum(1) / (
|
||||
pred_masks.flatten(1).sum(1) + 1e-6
|
||||
)
|
||||
pred_scores = scores * mask_scores
|
||||
|
||||
segmentation = torch.zeros(original_image_sizes[i], device=device) - 1
|
||||
|
||||
instance_maps, segments = [], []
|
||||
current_segment_id = 0
|
||||
for j in range(num_queries):
|
||||
score = pred_scores[j].item()
|
||||
|
||||
if not torch.all(pred_masks[j] == 0) and score >= threshold:
|
||||
segmentation[pred_masks[j] == 1] = current_segment_id
|
||||
segments.append(
|
||||
{
|
||||
"id": current_segment_id,
|
||||
"label_id": pred_classes[j].item(),
|
||||
"score": round(score, 6),
|
||||
}
|
||||
)
|
||||
current_segment_id += 1
|
||||
instance_maps.append(pred_masks[j])
|
||||
|
||||
results.append({"segmentation": segmentation, "segments_info": segments})
|
||||
return results
|
||||
|
||||
|
||||
__all__ = ["EomtImageProcessor"]
|
||||
580
src/transformers/models/eomt/image_processing_eomt_fast.py
Normal file
580
src/transformers/models/eomt/image_processing_eomt_fast.py
Normal file
@@ -0,0 +1,580 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 Mobile Perception Systems Lab at TU/e and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Fast Image processor class for EoMT."""
|
||||
|
||||
import math
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...image_processing_utils import BatchFeature
|
||||
from ...image_processing_utils_fast import (
|
||||
BaseImageProcessorFast,
|
||||
DefaultFastImageProcessorKwargs,
|
||||
group_images_by_shape,
|
||||
reorder_images,
|
||||
)
|
||||
from ...image_utils import (
|
||||
IMAGENET_DEFAULT_MEAN,
|
||||
IMAGENET_DEFAULT_STD,
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
SizeDict,
|
||||
make_list_of_images,
|
||||
pil_torch_interpolation_mapping,
|
||||
validate_kwargs,
|
||||
)
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
TensorType,
|
||||
auto_docstring,
|
||||
is_torch_available,
|
||||
is_torchvision_available,
|
||||
is_torchvision_v2_available,
|
||||
)
|
||||
from .image_processing_eomt import (
|
||||
compute_segments,
|
||||
convert_segmentation_map_to_binary_masks,
|
||||
get_size_with_aspect_ratio,
|
||||
remove_low_and_no_objects,
|
||||
)
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_torchvision_available():
|
||||
if is_torchvision_v2_available():
|
||||
from torchvision.transforms.v2 import functional as F
|
||||
else:
|
||||
from torchvision.transforms import functional as F
|
||||
|
||||
|
||||
class EomtImageProcessorFastKwargs(DefaultFastImageProcessorKwargs):
|
||||
"""
|
||||
do_split_image (`bool`, *optional*, defaults to `False`):
|
||||
Whether to split the input images into overlapping patches for semantic segmentation. If set to `True`, the
|
||||
input images will be split into patches of size `size["shortest_edge"]` with an overlap between patches.
|
||||
Otherwise, the input images will be padded to the target size.
|
||||
do_pad (`bool`, *optional*, defaults to `False`):
|
||||
Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
|
||||
number of patches in the batch. Padding will be applied to the bottom and right with zeros.
|
||||
ignore_index (`int`, *optional*):
|
||||
Label to be assigned to background pixels in segmentation maps. If provided, segmentation map pixels
|
||||
denoted with 0 (background) will be replaced with `ignore_index`.
|
||||
"""
|
||||
|
||||
do_split_image: bool
|
||||
do_pad: bool
|
||||
ignore_index: Optional[int] = None
|
||||
|
||||
|
||||
def get_target_size(size_dict: dict[str, int]) -> tuple[int, int]:
|
||||
"""Returns the height and width from a size dict."""
|
||||
target_height = size_dict["shortest_edge"]
|
||||
target_width = size_dict["longest_edge"] or target_height
|
||||
|
||||
return target_height, target_width
|
||||
|
||||
|
||||
def reorder_patches_and_offsets(
|
||||
patches: list[torch.Tensor], offsets: list[list[int]]
|
||||
) -> tuple[list[torch.Tensor], list[list[int]]]:
|
||||
"""Sorts patches and offsets according to the original image index."""
|
||||
|
||||
combined = list(zip(offsets, patches))
|
||||
combined.sort(key=lambda x: x[0][0])
|
||||
sorted_offsets, sorted_patches = zip(*combined)
|
||||
|
||||
return list(sorted_patches), list(sorted_offsets)
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class EomtImageProcessorFast(BaseImageProcessorFast):
|
||||
resample = PILImageResampling.BILINEAR
|
||||
image_mean = IMAGENET_DEFAULT_MEAN
|
||||
image_std = IMAGENET_DEFAULT_STD
|
||||
size = {"shortest_edge": 640, "longest_edge": 640}
|
||||
default_to_square = False
|
||||
do_resize = True
|
||||
do_rescale = True
|
||||
do_normalize = True
|
||||
do_split_image = False
|
||||
do_pad = False
|
||||
ignore_index = None
|
||||
valid_kwargs = EomtImageProcessorFastKwargs
|
||||
|
||||
def __init__(self, **kwargs: Unpack[EomtImageProcessorFastKwargs]):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _split_image(self, images: torch.Tensor, size: dict, image_indices: int) -> tuple[list, list]:
|
||||
"""Slices an image into overlapping patches for semantic segmentation."""
|
||||
|
||||
patches, patch_offsets = [], []
|
||||
|
||||
_, _, height, width = images.shape
|
||||
patch_size = size["shortest_edge"]
|
||||
|
||||
longer_side = max(height, width)
|
||||
num_patches = math.ceil(longer_side / patch_size)
|
||||
total_overlap = num_patches * patch_size - longer_side
|
||||
overlap_per_patch = total_overlap / (num_patches - 1) if num_patches > 1 else 0
|
||||
|
||||
for i in range(num_patches):
|
||||
start = int(i * (patch_size - overlap_per_patch))
|
||||
end = start + patch_size
|
||||
|
||||
if height > width:
|
||||
batch_patch = images[:, :, start:end, :]
|
||||
else:
|
||||
batch_patch = images[:, :, :, start:end]
|
||||
|
||||
for batch_idx, single in enumerate(torch.unbind(batch_patch, dim=0)):
|
||||
patches.append(single)
|
||||
patch_offsets.append([image_indices[batch_idx], start, end])
|
||||
|
||||
return patches, patch_offsets
|
||||
|
||||
def _pad(self, images: torch.Tensor, size: dict) -> torch.Tensor:
|
||||
"""Pads the image to the target size using zero padding."""
|
||||
_, _, height, width = images.shape
|
||||
|
||||
target_height, target_width = get_target_size(size)
|
||||
pad_h = max(0, target_height - height)
|
||||
pad_w = max(0, target_width - width)
|
||||
padding = (0, pad_w, 0, pad_h)
|
||||
|
||||
padded_images = torch.nn.functional.pad(images, padding, mode="constant", value=0.0)
|
||||
return padded_images
|
||||
|
||||
def _preprocess(
|
||||
self,
|
||||
images: list["torch.Tensor"],
|
||||
do_resize: bool,
|
||||
size: SizeDict,
|
||||
interpolation: Optional["F.InterpolationMode"],
|
||||
do_rescale: bool,
|
||||
rescale_factor: float,
|
||||
do_normalize: bool,
|
||||
do_split_image: bool,
|
||||
do_pad: bool,
|
||||
image_mean: Optional[Union[float, list[float]]],
|
||||
image_std: Optional[Union[float, list[float]]],
|
||||
disable_grouping: Optional[bool],
|
||||
return_tensors: Optional[Union[str, TensorType]],
|
||||
**kwargs,
|
||||
):
|
||||
"""Preprocesses the input images and masks if provided."""
|
||||
processed_images, patch_offsets = [], []
|
||||
|
||||
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
|
||||
resized_images_grouped = {}
|
||||
|
||||
for shape, stacked_images in grouped_images.items():
|
||||
if do_resize:
|
||||
stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation)
|
||||
resized_images_grouped[shape] = stacked_images
|
||||
images = reorder_images(resized_images_grouped, grouped_images_index)
|
||||
|
||||
# Group images by size for batched resizing, Needed in case do_resize is False.
|
||||
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
|
||||
processed_images_grouped = {}
|
||||
|
||||
for shape, stacked_images in grouped_images.items():
|
||||
original_indices = [
|
||||
original_idx for original_idx, (img_shape, _) in grouped_images_index.items() if img_shape == shape
|
||||
]
|
||||
|
||||
if do_split_image:
|
||||
patches, offsets = self._split_image(stacked_images, size, original_indices)
|
||||
processed_images.extend(patches)
|
||||
patch_offsets.extend(offsets)
|
||||
|
||||
if do_pad:
|
||||
stacked_images = self._pad(stacked_images, size)
|
||||
processed_images_grouped[shape] = stacked_images
|
||||
|
||||
if do_split_image:
|
||||
images, patch_offsets = reorder_patches_and_offsets(processed_images, patch_offsets)
|
||||
|
||||
if do_pad:
|
||||
images = reorder_images(processed_images_grouped, grouped_images_index)
|
||||
|
||||
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
|
||||
processed_images_grouped = {}
|
||||
|
||||
for shape, stacked_images in grouped_images.items():
|
||||
stacked_images = self.rescale_and_normalize(
|
||||
stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
|
||||
)
|
||||
processed_images_grouped[shape] = stacked_images
|
||||
images = reorder_images(processed_images_grouped, grouped_images_index)
|
||||
|
||||
processed_images = torch.stack(images, dim=0) if return_tensors else images
|
||||
|
||||
return processed_images, patch_offsets
|
||||
|
||||
def _preprocess_images(self, images, **kwargs):
|
||||
"""Preprocesses the input images."""
|
||||
return self._preprocess(images, **kwargs)
|
||||
|
||||
def _preprocess_masks(self, segmentation_maps: list[torch.Tensor], **kwargs):
|
||||
"""Preprocesses segmentation maps."""
|
||||
processed_segmentation_maps = []
|
||||
for segmentation_map in segmentation_maps:
|
||||
segmentation_map = self._process_image(
|
||||
segmentation_map, do_convert_rgb=False, input_data_format=ChannelDimension.FIRST
|
||||
)
|
||||
|
||||
if segmentation_map.ndim == 2:
|
||||
segmentation_map = segmentation_map[None, ...]
|
||||
|
||||
processed_segmentation_maps.append(segmentation_map)
|
||||
|
||||
kwargs["do_normalize"] = False
|
||||
kwargs["do_rescale"] = False
|
||||
kwargs["input_data_format"] = ChannelDimension.FIRST
|
||||
|
||||
# Nearest interpolation is used for segmentation maps instead of BILINEAR.
|
||||
kwargs["interpolation"] = pil_torch_interpolation_mapping[PILImageResampling.NEAREST]
|
||||
|
||||
processed_segmentation_maps, _ = self._preprocess(images=processed_segmentation_maps, **kwargs)
|
||||
processed_segmentation_maps = processed_segmentation_maps.squeeze(1)
|
||||
processed_segmentation_maps = processed_segmentation_maps.to(torch.int64)
|
||||
|
||||
return processed_segmentation_maps
|
||||
|
||||
@auto_docstring
|
||||
def preprocess(
|
||||
self,
|
||||
images: ImageInput,
|
||||
segmentation_maps: Optional[list[torch.Tensor]] = None,
|
||||
instance_id_to_semantic_id: Optional[dict[int, int]] = None,
|
||||
**kwargs: Unpack[EomtImageProcessorFastKwargs],
|
||||
) -> BatchFeature:
|
||||
r"""
|
||||
segmentation_maps (`ImageInput`, *optional*):
|
||||
The segmentation maps to preprocess for corresponding images.
|
||||
instance_id_to_semantic_id (`List[Dict[int, int]]` or `Dict[int, int]`, *optional*):
|
||||
A mapping between object instance ids and class ids.
|
||||
"""
|
||||
# args are not validated, but their order in the `preprocess` and `_preprocess` signatures must be the same
|
||||
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_kwargs_names)
|
||||
# Set default kwargs from self. This ensures that if a kwarg is not provided
|
||||
# by the user, it gets its default value from the instance, or is set to None.
|
||||
for kwarg_name in self._valid_kwargs_names:
|
||||
kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
|
||||
|
||||
# Extract parameters that are only used for preparing the input images
|
||||
do_convert_rgb = kwargs.pop("do_convert_rgb")
|
||||
input_data_format = kwargs.pop("input_data_format")
|
||||
device = kwargs.pop("device")
|
||||
# Prepare input images
|
||||
images = self._prepare_input_images(
|
||||
images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
|
||||
)
|
||||
# Prepare segmentation maps
|
||||
if segmentation_maps is not None:
|
||||
segmentation_maps = make_list_of_images(images=segmentation_maps, expected_ndims=2)
|
||||
|
||||
# Update kwargs that need further processing before being validated
|
||||
kwargs = self._further_process_kwargs(**kwargs)
|
||||
|
||||
# Validate kwargs
|
||||
self._validate_preprocess_kwargs(**kwargs)
|
||||
|
||||
# torch resize uses interpolation instead of resample
|
||||
resample = kwargs.pop("resample")
|
||||
|
||||
# Check if resample is an int before checking if it's an instance of PILImageResampling
|
||||
# because if pillow < 9.1.0, resample is an int and PILImageResampling is a module.
|
||||
# Checking PILImageResampling will fail with error `TypeError: isinstance() arg 2 must be a type or tuple of types`.
|
||||
kwargs["interpolation"] = (
|
||||
pil_torch_interpolation_mapping[resample] if isinstance(resample, (int, PILImageResampling)) else resample
|
||||
)
|
||||
|
||||
# Pop kwargs that are not needed in _preprocess
|
||||
kwargs.pop("default_to_square")
|
||||
kwargs.pop("data_format")
|
||||
|
||||
ignore_index = kwargs.pop("ignore_index", None)
|
||||
|
||||
processed_images, patch_offsets = self._preprocess_images(images=images, **kwargs)
|
||||
|
||||
outputs = BatchFeature({"pixel_values": processed_images})
|
||||
|
||||
mask_labels, class_labels = [], []
|
||||
if segmentation_maps is not None:
|
||||
segmentation_maps = self._preprocess_masks(segmentation_maps=segmentation_maps, **kwargs)
|
||||
# Convert to list of binary masks and labels
|
||||
for idx, segmentation_map in enumerate(segmentation_maps):
|
||||
if isinstance(instance_id_to_semantic_id, list):
|
||||
instance_id = instance_id_to_semantic_id[idx]
|
||||
else:
|
||||
instance_id = instance_id_to_semantic_id
|
||||
# Use instance2class_id mapping per image
|
||||
masks, classes = convert_segmentation_map_to_binary_masks(
|
||||
segmentation_map,
|
||||
instance_id,
|
||||
ignore_index=ignore_index,
|
||||
)
|
||||
|
||||
mask_labels.append(torch.from_numpy(masks))
|
||||
class_labels.append(torch.from_numpy(classes))
|
||||
|
||||
# we cannot batch them since they don't share a common class size
|
||||
outputs["mask_labels"] = mask_labels
|
||||
outputs["class_labels"] = class_labels
|
||||
|
||||
if patch_offsets:
|
||||
outputs["patch_offsets"] = patch_offsets
|
||||
|
||||
return outputs
|
||||
|
||||
def merge_image_patches(
|
||||
self,
|
||||
segmentation_logits: torch.Tensor,
|
||||
patch_offsets: list[tuple[int, int, int]],
|
||||
original_image_sizes: list[tuple[int, int]],
|
||||
size: dict[str, int],
|
||||
) -> list[torch.Tensor]:
|
||||
"""
|
||||
Reconstructs full-size semantic segmentation logits from patch predictions.
|
||||
|
||||
Args:
|
||||
segmentation_logits (`torch.Tensor`):
|
||||
A tensor of shape `(num_patches, num_classes, patch_height, patch_width)` representing predicted logits
|
||||
for each image patch.
|
||||
patch_offsets (`List[Tuple[int, int, int]]`):
|
||||
A list of tuples where each tuple contains:
|
||||
- `image_index` (int): Index of the original image this patch belongs to.
|
||||
- `start` (int): Start pixel index of the patch along the long dimension (height or width).
|
||||
- `end` (int): End pixel index of the patch along the long dimension.
|
||||
original_image_sizes (`List[Tuple[int, int]]`):
|
||||
List of original (height, width) dimensions for each image before preprocessing.
|
||||
size (`Dict[str, int]`):
|
||||
A size dict which was used to resize.
|
||||
"""
|
||||
num_classes = segmentation_logits.shape[1]
|
||||
aggregated_logits = []
|
||||
patch_counts = []
|
||||
|
||||
for image_size in original_image_sizes:
|
||||
height, width = get_size_with_aspect_ratio(image_size, size["shortest_edge"], size["longest_edge"])
|
||||
aggregated_logits.append(torch.zeros((num_classes, height, width), device=segmentation_logits.device))
|
||||
patch_counts.append(torch.zeros((num_classes, height, width), device=segmentation_logits.device))
|
||||
|
||||
# Stitch patches back into full-sized logit maps
|
||||
for patch_idx, (image_idx, patch_start, patch_end) in enumerate(patch_offsets):
|
||||
if original_image_sizes[image_idx][0] > original_image_sizes[image_idx][1]:
|
||||
aggregated_logits[image_idx][:, patch_start:patch_end, :] += segmentation_logits[patch_idx]
|
||||
patch_counts[image_idx][:, patch_start:patch_end, :] += 1
|
||||
else:
|
||||
aggregated_logits[image_idx][:, :, patch_start:patch_end] += segmentation_logits[patch_idx]
|
||||
patch_counts[image_idx][:, :, patch_start:patch_end] += 1
|
||||
|
||||
# Normalize and resize logits to original image size
|
||||
reconstructed_logits = []
|
||||
for idx, (logit_sum, count) in enumerate(zip(aggregated_logits, patch_counts)):
|
||||
averaged_logits = logit_sum / count.clamp(min=1)
|
||||
resized_logits = torch.nn.functional.interpolate(
|
||||
averaged_logits[None, ...],
|
||||
size=original_image_sizes[idx],
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
)[0]
|
||||
|
||||
reconstructed_logits.append(resized_logits)
|
||||
|
||||
return reconstructed_logits
|
||||
|
||||
def unpad_image(
|
||||
self,
|
||||
segmentation_logits: torch.Tensor,
|
||||
original_image_sizes: list[tuple[int, int]],
|
||||
size: dict[str, int],
|
||||
) -> list[torch.Tensor]:
|
||||
"""Restores panoptic segmentation logits to their original image resolutions."""
|
||||
|
||||
resized_logits = []
|
||||
|
||||
for idx, original_size in enumerate(original_image_sizes):
|
||||
target_height, target_width = get_size_with_aspect_ratio(
|
||||
original_size, size["shortest_edge"], size["longest_edge"]
|
||||
)
|
||||
cropped_logits = segmentation_logits[idx][:, :target_height, :target_width]
|
||||
upsampled_logits = torch.nn.functional.interpolate(
|
||||
cropped_logits[None, ...], size=original_size, mode="bilinear", align_corners=False
|
||||
)[0]
|
||||
resized_logits.append(upsampled_logits)
|
||||
return resized_logits
|
||||
|
||||
def post_process_semantic_segmentation(
|
||||
self,
|
||||
outputs,
|
||||
patch_offsets: list[tuple[int, int, int]],
|
||||
original_image_sizes: list[tuple[int, int]],
|
||||
size: Optional[dict[str, int]] = None,
|
||||
) -> np.ndarray:
|
||||
"""Post-processes model outputs into final semantic segmentation prediction."""
|
||||
|
||||
size = size if size is not None else self.size
|
||||
|
||||
masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width]
|
||||
class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1]
|
||||
|
||||
output_size = get_target_size(size)
|
||||
masks_queries_logits = torch.nn.functional.interpolate(
|
||||
masks_queries_logits,
|
||||
size=output_size,
|
||||
mode="bilinear",
|
||||
)
|
||||
|
||||
# Remove the null class `[..., :-1]`
|
||||
masks_classes = class_queries_logits.softmax(dim=-1)[..., :-1]
|
||||
masks_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width]
|
||||
|
||||
segmentation_logits = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs)
|
||||
|
||||
output_logits = self.merge_image_patches(segmentation_logits, patch_offsets, original_image_sizes, size)
|
||||
|
||||
preds = torch.stack(output_logits).argmax(dim=1)
|
||||
return preds
|
||||
|
||||
def post_process_panoptic_segmentation(
|
||||
self,
|
||||
outputs,
|
||||
original_image_sizes: list[tuple[int, int]],
|
||||
threshold: float = 0.8,
|
||||
mask_threshold: float = 0.5,
|
||||
overlap_mask_area_threshold: float = 0.8,
|
||||
stuff_classes: Optional[list[int]] = None,
|
||||
size: Optional[dict[str, int]] = None,
|
||||
):
|
||||
"""Post-processes model outputs into final panoptic segmentation prediction."""
|
||||
|
||||
size = size if size is not None else self.size
|
||||
|
||||
masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width]
|
||||
class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1]
|
||||
|
||||
batch_size = class_queries_logits.shape[0]
|
||||
num_labels = class_queries_logits.shape[-1] - 1
|
||||
|
||||
output_size = get_target_size(size)
|
||||
masks_queries_logits = torch.nn.functional.interpolate(
|
||||
masks_queries_logits,
|
||||
size=output_size,
|
||||
mode="bilinear",
|
||||
)
|
||||
|
||||
mask_probs_batch = self.unpad_image(masks_queries_logits, original_image_sizes, size)
|
||||
pred_scores_batch, pred_labels_batch = class_queries_logits.softmax(dim=-1).max(-1)
|
||||
|
||||
results: list = []
|
||||
|
||||
for i in range(batch_size):
|
||||
mask_probs, pred_scores, pred_labels = remove_low_and_no_objects(
|
||||
mask_probs_batch[i], pred_scores_batch[i], pred_labels_batch[i], threshold, num_labels
|
||||
)
|
||||
|
||||
# No mask found
|
||||
if mask_probs.shape[0] <= 0:
|
||||
height, width = original_image_sizes[i] if original_image_sizes is not None else mask_probs.shape[1:]
|
||||
segmentation = torch.zeros((height, width)) - 1
|
||||
results.append({"segmentation": segmentation, "segments_info": []})
|
||||
continue
|
||||
|
||||
segmentation, segments = compute_segments(
|
||||
mask_probs=mask_probs,
|
||||
pred_scores=pred_scores,
|
||||
pred_labels=pred_labels,
|
||||
stuff_classes=stuff_classes,
|
||||
mask_threshold=mask_threshold,
|
||||
overlap_mask_area_threshold=overlap_mask_area_threshold,
|
||||
target_size=original_image_sizes[i] if original_image_sizes is not None else None,
|
||||
)
|
||||
|
||||
results.append({"segmentation": segmentation, "segments_info": segments})
|
||||
return results
|
||||
|
||||
def post_process_instance_segmentation(
|
||||
self,
|
||||
outputs,
|
||||
original_image_sizes: list[tuple[int, int]],
|
||||
threshold: float = 0.8,
|
||||
size: Optional[dict[str, int]] = None,
|
||||
):
|
||||
"""Post-processes model outputs into Instance Segmentation Predictions."""
|
||||
|
||||
size = size if size is not None else self.size
|
||||
|
||||
masks_queries_logits = outputs.masks_queries_logits
|
||||
class_queries_logits = outputs.class_queries_logits
|
||||
|
||||
output_size = get_target_size(size)
|
||||
masks_queries_logits = torch.nn.functional.interpolate(
|
||||
masks_queries_logits,
|
||||
size=output_size,
|
||||
mode="bilinear",
|
||||
)
|
||||
|
||||
mask_probs_batch = self.unpad_image(masks_queries_logits, original_image_sizes, size)
|
||||
|
||||
device = masks_queries_logits.device
|
||||
batch_size = class_queries_logits.shape[0]
|
||||
num_queries = class_queries_logits.shape[-2]
|
||||
|
||||
results = []
|
||||
|
||||
for i in range(batch_size):
|
||||
mask_pred = mask_probs_batch[i]
|
||||
mask_class = class_queries_logits[i]
|
||||
|
||||
# Remove the null class `[..., :-1]`
|
||||
scores, pred_classes = mask_class.softmax(dim=-1)[..., :-1].max(-1)
|
||||
pred_masks = (mask_pred > 0).float()
|
||||
|
||||
# Calculate average mask prob
|
||||
mask_scores = (mask_pred.sigmoid().flatten(1) * pred_masks.flatten(1)).sum(1) / (
|
||||
pred_masks.flatten(1).sum(1) + 1e-6
|
||||
)
|
||||
pred_scores = scores * mask_scores
|
||||
|
||||
segmentation = torch.zeros(original_image_sizes[i], device=device) - 1
|
||||
|
||||
instance_maps, segments = [], []
|
||||
current_segment_id = 0
|
||||
for j in range(num_queries):
|
||||
score = pred_scores[j].item()
|
||||
|
||||
if not torch.all(pred_masks[j] == 0) and score >= threshold:
|
||||
segmentation[pred_masks[j] == 1] = current_segment_id
|
||||
segments.append(
|
||||
{
|
||||
"id": current_segment_id,
|
||||
"label_id": pred_classes[j].item(),
|
||||
"score": round(score, 6),
|
||||
}
|
||||
)
|
||||
current_segment_id += 1
|
||||
instance_maps.append(pred_masks[j])
|
||||
|
||||
results.append({"segmentation": segmentation, "segments_info": segments})
|
||||
return results
|
||||
|
||||
|
||||
__all__ = ["EomtImageProcessorFast"]
|
||||
1242
src/transformers/models/eomt/modeling_eomt.py
Normal file
1242
src/transformers/models/eomt/modeling_eomt.py
Normal file
File diff suppressed because it is too large
Load Diff
588
src/transformers/models/eomt/modular_eomt.py
Normal file
588
src/transformers/models/eomt/modular_eomt.py
Normal file
@@ -0,0 +1,588 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 Mobile Perception Systems Lab at TU/e and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch EoMT model."""
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor, nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...file_utils import (
|
||||
ModelOutput,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import (
|
||||
auto_docstring,
|
||||
can_return_tuple,
|
||||
logging,
|
||||
)
|
||||
from ..dinov2.modeling_dinov2 import (
|
||||
Dinov2Embeddings,
|
||||
Dinov2Layer,
|
||||
Dinov2LayerScale,
|
||||
Dinov2PatchEmbeddings,
|
||||
)
|
||||
from ..mask2former.modeling_mask2former import Mask2FormerForUniversalSegmentation, Mask2FormerLoss
|
||||
from ..siglip.modeling_siglip import SiglipAttention
|
||||
from ..vit.configuration_vit import ViTConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class EomtConfig(ViTConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`EomtForUniversalSegmentation`]. It is used to instantiate an EoMT model
|
||||
according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||
defaults will yield a similar configuration to that of the EoMT
|
||||
[tue-mps/coco_panoptic_eomt_large_640](https://huggingface.co/tue-mps/coco_panoptic_eomt_large_640)
|
||||
architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
hidden_size (`int`, *optional*, defaults to 1024):
|
||||
Dimensionality of the hidden representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 24):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 16):
|
||||
Number of attention heads in each attention layer.
|
||||
mlp_ratio (`int`, *optional*, defaults to 4):
|
||||
Ratio of the MLP hidden dimensionality to the hidden size.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
|
||||
The non-linear activation function (function or string) in the encoder.
|
||||
hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
|
||||
The dropout probability for all fully connected layers in the embeddings and encoder.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||
The epsilon used by the layer normalization layers.
|
||||
image_size (`int`, *optional*, defaults to 640):
|
||||
The size (resolution) of each input image.
|
||||
patch_size (`int`, *optional*, defaults to 16):
|
||||
The size (resolution) of each patch.
|
||||
num_channels (`int`, *optional*, defaults to 3):
|
||||
The number of input channels.
|
||||
layerscale_value (`float`, *optional*, defaults to 1.0):
|
||||
Initial value for the LayerScale parameter.
|
||||
drop_path_rate (`float`, *optional*, defaults to 0.0):
|
||||
The stochastic depth rate (drop path) used during training.
|
||||
num_upscale_blocks (`int`, *optional*, defaults to 2):
|
||||
Number of upsampling blocks used in the decoder or segmentation head.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
Dropout probability applied after attention projection.
|
||||
use_swiglu_ffn (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use the SwiGLU feedforward neural network.
|
||||
num_blocks (`int`, *optional*, defaults to 4):
|
||||
Number of feature blocks or stages in the architecture.
|
||||
no_object_weight (`float`, *optional*, defaults to 0.1):
|
||||
Loss weight for the 'no object' class in panoptic/instance segmentation.
|
||||
class_weight (`float`, *optional*, defaults to 2.0):
|
||||
Loss weight for classification targets.
|
||||
mask_weight (`float`, *optional*, defaults to 5.0):
|
||||
Loss weight for mask prediction.
|
||||
dice_weight (`float`, *optional*, defaults to 5.0):
|
||||
Loss weight for the dice loss component.
|
||||
train_num_points (`int`, *optional*, defaults to 12544):
|
||||
Number of points to sample for mask loss computation during training.
|
||||
oversample_ratio (`float`, *optional*, defaults to 3.0):
|
||||
Oversampling ratio used in point sampling for mask training.
|
||||
importance_sample_ratio (`float`, *optional*, defaults to 0.75):
|
||||
Ratio of points to sample based on importance during training.
|
||||
num_queries (`int`, *optional*, defaults to 200):
|
||||
Number of object queries in the Transformer.
|
||||
num_register_tokens (`int`, *optional*, defaults to 4):
|
||||
Number of learnable register tokens added to the transformer input.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import EomtConfig, EomtForUniversalSegmentation
|
||||
|
||||
>>> # Initialize configuration
|
||||
>>> config = EomtConfig()
|
||||
|
||||
>>> # Initialize model
|
||||
>>> model = EomtForUniversalSegmentation(config)
|
||||
|
||||
>>> # Access config
|
||||
>>> config = model.config
|
||||
```"""
|
||||
|
||||
model_type = "eomt"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size=1024,
|
||||
num_hidden_layers=24,
|
||||
num_attention_heads=16,
|
||||
mlp_ratio=4,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.0,
|
||||
initializer_range=0.02,
|
||||
layer_norm_eps=1e-6,
|
||||
image_size=640,
|
||||
patch_size=16,
|
||||
num_channels=3,
|
||||
layerscale_value=1.0,
|
||||
drop_path_rate=0.0,
|
||||
num_upscale_blocks=2,
|
||||
attention_dropout=0.0,
|
||||
use_swiglu_ffn=False,
|
||||
num_blocks=4,
|
||||
no_object_weight: float = 0.1,
|
||||
class_weight: float = 2.0,
|
||||
mask_weight: float = 5.0,
|
||||
dice_weight: float = 5.0,
|
||||
train_num_points: int = 12544,
|
||||
oversample_ratio: float = 3.0,
|
||||
importance_sample_ratio: float = 0.75,
|
||||
num_queries=200,
|
||||
num_register_tokens=4,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
hidden_size=hidden_size,
|
||||
num_hidden_layers=num_hidden_layers,
|
||||
num_attention_heads=num_attention_heads,
|
||||
hidden_dropout_prob=hidden_dropout_prob,
|
||||
hidden_act=hidden_act,
|
||||
initializer_range=initializer_range,
|
||||
layer_norm_eps=layer_norm_eps,
|
||||
image_size=image_size,
|
||||
patch_size=patch_size,
|
||||
num_channels=num_channels,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
del self.intermediate_size
|
||||
del self.qkv_bias
|
||||
del self.pooler_act
|
||||
del self.pooler_output_size
|
||||
del self.encoder_stride
|
||||
del self.attention_probs_dropout_prob
|
||||
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.attention_dropout = attention_dropout
|
||||
self.layerscale_value = layerscale_value
|
||||
self.drop_path_rate = drop_path_rate
|
||||
self.num_upscale_blocks = num_upscale_blocks
|
||||
self.use_swiglu_ffn = use_swiglu_ffn
|
||||
self.num_blocks = num_blocks
|
||||
self.no_object_weight = no_object_weight
|
||||
self.class_weight = class_weight
|
||||
self.mask_weight = mask_weight
|
||||
self.dice_weight = dice_weight
|
||||
self.train_num_points = train_num_points
|
||||
self.oversample_ratio = oversample_ratio
|
||||
self.importance_sample_ratio = importance_sample_ratio
|
||||
self.num_queries = num_queries
|
||||
self.num_register_tokens = num_register_tokens
|
||||
|
||||
|
||||
@dataclass
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
Class for outputs of [`EomtForUniversalSegmentationOutput`].
|
||||
|
||||
This output can be directly passed to [`~EomtImageProcessor.post_process_semantic_segmentation`] or
|
||||
[`~EomtImageProcessor.post_process_instance_segmentation`] or
|
||||
[`~EomtImageProcessor.post_process_panoptic_segmentation`] to compute final segmentation maps. Please, see
|
||||
[`~EomtImageProcessor] for details regarding usage.
|
||||
"""
|
||||
)
|
||||
class EomtForUniversalSegmentationOutput(ModelOutput):
|
||||
r"""
|
||||
loss (`torch.Tensor`, *optional*):
|
||||
The computed loss, returned when labels are present.
|
||||
class_queries_logits (`torch.FloatTensor`):
|
||||
A tensor of shape `(batch_size, num_queries, num_labels + 1)` representing the proposed classes for each
|
||||
query. Note the `+ 1` is needed because we incorporate the null class.
|
||||
masks_queries_logits (`torch.FloatTensor`):
|
||||
A tensor of shape `(batch_size, num_queries, height, width)` representing the proposed masks for each
|
||||
query.
|
||||
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||||
Last hidden states (final feature map) of the last layer.
|
||||
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
|
||||
shape `(batch_size, sequence_length, hidden_size)`. Hidden-states all layers of the model.
|
||||
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
||||
sequence_length)`. Self and Cross Attentions weights from transformer decoder.
|
||||
"""
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
class_queries_logits: Optional[torch.FloatTensor] = None
|
||||
masks_queries_logits: Optional[torch.FloatTensor] = None
|
||||
last_hidden_state: Optional[torch.FloatTensor] = None
|
||||
hidden_states: Optional[tuple[torch.FloatTensor]] = None
|
||||
attentions: Optional[tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
class EomtLoss(Mask2FormerLoss):
|
||||
pass
|
||||
|
||||
|
||||
class EomtPatchEmbeddings(Dinov2PatchEmbeddings):
|
||||
pass
|
||||
|
||||
|
||||
class EomtEmbeddings(Dinov2Embeddings, nn.Module):
|
||||
def __init__(self, config: EomtConfig) -> None:
|
||||
Dinov2Embeddings().__init__()
|
||||
|
||||
self.config = config
|
||||
self.patch_size = config.patch_size
|
||||
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
|
||||
self.register_tokens = nn.Parameter(torch.zeros(1, config.num_register_tokens, config.hidden_size))
|
||||
|
||||
self.patch_embeddings = EomtPatchEmbeddings(config)
|
||||
num_patches = self.patch_embeddings.num_patches
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.num_prefix_tokens = 1 + config.num_register_tokens # 1 for [CLS]
|
||||
self.position_embeddings = nn.Embedding(num_patches, config.hidden_size)
|
||||
self.register_buffer("position_ids", torch.arange(num_patches).expand((1, -1)), persistent=False)
|
||||
|
||||
def interpolate_pos_encoding(self):
|
||||
raise AttributeError("Not needed for Eomt Model")
|
||||
|
||||
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, _, _, _ = pixel_values.shape
|
||||
target_dtype = self.patch_embeddings.projection.weight.dtype
|
||||
embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype))
|
||||
|
||||
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
|
||||
register_tokens = self.register_tokens.expand(batch_size, -1, -1)
|
||||
|
||||
embeddings = embeddings + self.position_embeddings(self.position_ids)
|
||||
embeddings = torch.cat([cls_tokens, register_tokens, embeddings], dim=1)
|
||||
|
||||
embeddings = self.dropout(embeddings)
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
class EomtAttention(SiglipAttention):
|
||||
pass
|
||||
|
||||
|
||||
class EomtLayerScale(Dinov2LayerScale):
|
||||
pass
|
||||
|
||||
|
||||
class EomtLayer(Dinov2Layer):
|
||||
pass
|
||||
|
||||
|
||||
class EomtLayerNorm2d(nn.LayerNorm):
|
||||
def __init__(self, num_channels, eps=1e-6, affine=True):
|
||||
super().__init__(num_channels, eps=eps, elementwise_affine=affine)
|
||||
|
||||
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
|
||||
hidden_state = hidden_state.permute(0, 2, 3, 1)
|
||||
hidden_state = F.layer_norm(hidden_state, self.normalized_shape, self.weight, self.bias, self.eps)
|
||||
hidden_state = hidden_state.permute(0, 3, 1, 2)
|
||||
return hidden_state
|
||||
|
||||
|
||||
class EomtScaleLayer(nn.Module):
|
||||
def __init__(self, config: EomtConfig):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
self.conv1 = nn.ConvTranspose2d(hidden_size, hidden_size, kernel_size=2, stride=2)
|
||||
self.activation = ACT2FN[config.hidden_act]
|
||||
self.conv2 = nn.Conv2d(
|
||||
hidden_size,
|
||||
hidden_size,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
groups=hidden_size,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
self.layernorm2d = EomtLayerNorm2d(hidden_size)
|
||||
|
||||
def forward(self, hidden_states: torch.tensor) -> torch.Tensor:
|
||||
hidden_states = self.conv1(hidden_states)
|
||||
hidden_states = self.activation(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
hidden_states = self.layernorm2d(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class EomtScaleBlock(nn.Module):
|
||||
def __init__(self, config: EomtConfig):
|
||||
super().__init__()
|
||||
self.num_blocks = config.num_upscale_blocks
|
||||
self.block = nn.ModuleList([EomtScaleLayer(config) for _ in range(self.num_blocks)])
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
for block in self.block:
|
||||
hidden_states = block(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class EomtMaskHead(nn.Module):
|
||||
def __init__(self, config: EomtConfig):
|
||||
super().__init__()
|
||||
|
||||
hidden_size = config.hidden_size
|
||||
self.fc1 = nn.Linear(hidden_size, hidden_size)
|
||||
self.fc2 = nn.Linear(hidden_size, hidden_size)
|
||||
self.fc3 = nn.Linear(hidden_size, hidden_size)
|
||||
self.activation = ACT2FN[config.hidden_act]
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.activation(self.fc1(hidden_states))
|
||||
hidden_states = self.activation(self.fc2(hidden_states))
|
||||
hidden_states = self.fc3(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class EomtPreTrainedModel(PreTrainedModel):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
models.
|
||||
"""
|
||||
|
||||
config_class = EomtConfig
|
||||
base_model_prefix = "eomt"
|
||||
main_input_name = "pixel_values"
|
||||
supports_gradient_checkpointing = False
|
||||
_no_split_modules = ["EomtMLP"]
|
||||
_supports_sdpa = True
|
||||
_supports_flash_attn_2 = True
|
||||
|
||||
def _init_weights(self, module: nn.Module) -> None:
|
||||
std = self.config.initializer_range
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
|
||||
nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
|
||||
if module.bias is not None:
|
||||
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
|
||||
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
|
||||
nn.init.uniform_(module.bias, -bound, bound)
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=1)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, EomtLayerScale):
|
||||
if hasattr(module, "lambda1"):
|
||||
module.lambda1.data.fill_(self.config.layerscale_value)
|
||||
elif isinstance(module, EomtEmbeddings):
|
||||
module.cls_token.data = nn.init.trunc_normal_(
|
||||
module.cls_token.data.to(torch.float32), mean=0.0, std=std
|
||||
).to(module.cls_token.dtype)
|
||||
module.register_tokens.data.zero_()
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
The EoMT Model with head on top for instance/semantic/panoptic segmentation.
|
||||
"""
|
||||
)
|
||||
class EomtForUniversalSegmentation(Mask2FormerForUniversalSegmentation, nn.Module):
|
||||
def __init__(self, config: EomtConfig) -> None:
|
||||
nn.Module().__init__(config)
|
||||
self.config = config
|
||||
self.num_hidden_layers = config.num_hidden_layers
|
||||
self.embeddings = EomtEmbeddings(config)
|
||||
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
|
||||
self.query = nn.Embedding(config.num_queries, config.hidden_size)
|
||||
self.layers = nn.ModuleList([EomtLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
|
||||
self.upscale_block = EomtScaleBlock(config)
|
||||
self.mask_head = EomtMaskHead(config)
|
||||
|
||||
self.class_predictor = nn.Linear(config.hidden_size, config.num_labels + 1)
|
||||
|
||||
self.grid_size = (config.image_size // config.patch_size, config.image_size // config.patch_size)
|
||||
self.weight_dict: dict[str, float] = {
|
||||
"loss_cross_entropy": config.class_weight,
|
||||
"loss_mask": config.mask_weight,
|
||||
"loss_dice": config.dice_weight,
|
||||
}
|
||||
|
||||
self.criterion = EomtLoss(config=config, weight_dict=self.weight_dict)
|
||||
|
||||
self.register_buffer("attn_mask_probs", torch.ones(config.num_blocks))
|
||||
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embeddings.patch_embeddings
|
||||
|
||||
def get_auxiliary_logits(self):
|
||||
raise AttributeError("Note needed for Eomt Model.")
|
||||
|
||||
def predict(self, logits: torch.Tensor):
|
||||
query_tokens = logits[:, : self.config.num_queries, :]
|
||||
class_logits = self.class_predictor(query_tokens)
|
||||
|
||||
prefix_tokens = logits[:, self.config.num_queries + self.embeddings.num_prefix_tokens :, :]
|
||||
prefix_tokens = prefix_tokens.transpose(1, 2)
|
||||
|
||||
prefix_tokens = prefix_tokens.reshape(prefix_tokens.shape[0], -1, *self.grid_size)
|
||||
|
||||
query_tokens = self.mask_head(query_tokens)
|
||||
prefix_tokens = self.upscale_block(prefix_tokens)
|
||||
|
||||
mask_logits = torch.einsum("bqc, bchw -> bqhw", query_tokens, prefix_tokens)
|
||||
|
||||
return mask_logits, class_logits
|
||||
|
||||
@staticmethod
|
||||
def _disable_attention_mask(attn_mask, prob, num_query_tokens, encoder_start_tokens, device):
|
||||
if prob < 1:
|
||||
# Generate random queries to disable based on the probs
|
||||
random_queries = torch.rand(attn_mask.shape[0], num_query_tokens, device=device) > prob
|
||||
|
||||
# Disable attention to the query tokens, considering the prefix tokens
|
||||
attn_mask[:, :num_query_tokens, encoder_start_tokens:][random_queries] = 1
|
||||
|
||||
return attn_mask
|
||||
|
||||
@auto_docstring
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: Tensor,
|
||||
mask_labels: Optional[list[Tensor]] = None,
|
||||
class_labels: Optional[list[Tensor]] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
):
|
||||
r"""
|
||||
mask_labels (`List[torch.Tensor]`, *optional*):
|
||||
List of mask labels of shape `(num_labels, height, width)` to be fed to a model
|
||||
class_labels (`List[torch.LongTensor]`, *optional*):
|
||||
list of target class labels of shape `(num_labels, height, width)` to be fed to a model. They identify the
|
||||
labels of `mask_labels`, e.g. the label of `mask_labels[i][j]` if `class_labels[i][j]`.
|
||||
"""
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_attentions = () if output_attentions else None
|
||||
|
||||
masks_queries_logits_per_layer, class_queries_logits_per_layer = (), ()
|
||||
attention_mask = None
|
||||
|
||||
if pixel_values is None:
|
||||
raise ValueError("You have to specify pixel_values")
|
||||
|
||||
hidden_states = self.embeddings(pixel_values)
|
||||
|
||||
for idx, layer_module in enumerate(self.layers):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
if idx == self.num_hidden_layers - self.config.num_blocks:
|
||||
query = self.query.weight[None, :, :].expand(hidden_states.shape[0], -1, -1)
|
||||
hidden_states = torch.cat((query, hidden_states), dim=1)
|
||||
|
||||
if idx >= self.num_hidden_layers - self.config.num_blocks and (
|
||||
self.training or self.attn_mask_probs[idx - self.num_hidden_layers + self.config.num_blocks] > 0
|
||||
):
|
||||
norm_hidden_states = self.layernorm(hidden_states)
|
||||
masks_queries_logits, class_queries_logits = self.predict(norm_hidden_states)
|
||||
|
||||
masks_queries_logits_per_layer += (masks_queries_logits,)
|
||||
class_queries_logits_per_layer += (class_queries_logits,)
|
||||
|
||||
attention_mask = torch.ones(
|
||||
hidden_states.shape[0],
|
||||
hidden_states.shape[1],
|
||||
hidden_states.shape[1],
|
||||
device=hidden_states.device,
|
||||
dtype=torch.bool,
|
||||
)
|
||||
|
||||
interpolated_logits = F.interpolate(masks_queries_logits, size=self.grid_size, mode="bilinear")
|
||||
interpolated_logits = interpolated_logits.view(
|
||||
interpolated_logits.size(0), interpolated_logits.size(1), -1
|
||||
)
|
||||
|
||||
num_query_tokens = self.config.num_queries
|
||||
encoder_start_tokens = num_query_tokens + self.embeddings.num_prefix_tokens
|
||||
|
||||
# Set attention mask for queries to focus on encoder tokens based on interpolated logits
|
||||
attention_mask[:, :num_query_tokens, encoder_start_tokens:] = interpolated_logits > 0
|
||||
|
||||
# Disable attention mask for random query tokens.
|
||||
attention_mask = self._disable_attention_mask(
|
||||
attention_mask,
|
||||
prob=self.attn_mask_probs[idx - self.num_hidden_layers + self.config.num_blocks],
|
||||
num_query_tokens=num_query_tokens,
|
||||
encoder_start_tokens=encoder_start_tokens,
|
||||
device=attention_mask.device,
|
||||
)
|
||||
|
||||
# Expand attention mask to 4d mask.
|
||||
attention_mask = attention_mask[:, None, ...].expand(-1, self.config.num_attention_heads, -1, -1)
|
||||
attention_mask = attention_mask.float().masked_fill(~attention_mask, -1e9)
|
||||
|
||||
layer_outputs = layer_module(hidden_states, attention_mask, output_attentions)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if output_attentions:
|
||||
all_attentions += (layer_outputs[1],)
|
||||
|
||||
sequence_output = self.layernorm(hidden_states)
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (sequence_output,)
|
||||
|
||||
masks_queries_logits, class_queries_logits = self.predict(sequence_output)
|
||||
masks_queries_logits_per_layer += (masks_queries_logits,)
|
||||
class_queries_logits_per_layer += (class_queries_logits,)
|
||||
|
||||
loss = None
|
||||
if mask_labels is not None and class_labels is not None:
|
||||
loss = 0.0
|
||||
for masks_queries_logits, class_queries_logits in zip(
|
||||
masks_queries_logits_per_layer, class_queries_logits_per_layer
|
||||
):
|
||||
loss_dict = self.get_loss_dict(
|
||||
masks_queries_logits=masks_queries_logits,
|
||||
class_queries_logits=class_queries_logits,
|
||||
mask_labels=mask_labels,
|
||||
class_labels=class_labels,
|
||||
auxiliary_predictions=None,
|
||||
)
|
||||
loss += self.get_loss(loss_dict)
|
||||
|
||||
return EomtForUniversalSegmentationOutput(
|
||||
loss=loss,
|
||||
masks_queries_logits=masks_queries_logits,
|
||||
class_queries_logits=class_queries_logits,
|
||||
last_hidden_state=sequence_output,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_attentions,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["EomtConfig", "EomtPreTrainedModel", "EomtForUniversalSegmentation"]
|
||||
@@ -512,7 +512,7 @@ class Mask2FormerLoss(nn.Module):
|
||||
self.importance_sample_ratio = config.importance_sample_ratio
|
||||
|
||||
self.matcher = Mask2FormerHungarianMatcher(
|
||||
cost_class=1.0,
|
||||
cost_class=config.class_weight,
|
||||
cost_dice=config.dice_weight,
|
||||
cost_mask=config.mask_weight,
|
||||
num_points=self.num_points,
|
||||
|
||||
0
tests/models/eomt/__init__.py
Normal file
0
tests/models/eomt/__init__.py
Normal file
308
tests/models/eomt/test_image_processing_eomt.py
Normal file
308
tests/models/eomt/test_image_processing_eomt.py
Normal file
@@ -0,0 +1,308 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Testing suite for the PyTorch EoMT Image Processor."""
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
from datasets import load_dataset
|
||||
|
||||
from transformers.testing_utils import require_torch, require_vision
|
||||
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
|
||||
|
||||
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
from transformers import EomtImageProcessor
|
||||
|
||||
if is_torchvision_available():
|
||||
from transformers import EomtImageProcessorFast
|
||||
from transformers.models.eomt.modeling_eomt import EomtForUniversalSegmentationOutput
|
||||
|
||||
|
||||
class EomtImageProcessingTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=7,
|
||||
num_channels=3,
|
||||
min_resolution=30,
|
||||
max_resolution=400,
|
||||
size=None,
|
||||
do_resize=True,
|
||||
do_pad=True,
|
||||
do_normalize=True,
|
||||
image_mean=[0.5, 0.5, 0.5],
|
||||
image_std=[0.5, 0.5, 0.5],
|
||||
num_labels=10,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.num_channels = num_channels
|
||||
self.min_resolution = min_resolution
|
||||
self.max_resolution = max_resolution
|
||||
self.do_resize = do_resize
|
||||
self.do_pad = do_pad
|
||||
self.size = size if size is not None else {"shortest_edge": 18, "longest_edge": 18}
|
||||
self.do_normalize = do_normalize
|
||||
self.image_mean = image_mean
|
||||
self.image_std = image_std
|
||||
# for the post_process_functions
|
||||
self.batch_size = 2
|
||||
self.num_queries = 3
|
||||
self.num_classes = 2
|
||||
self.height = 18
|
||||
self.width = 18
|
||||
self.num_labels = num_labels
|
||||
|
||||
def prepare_image_processor_dict(self):
|
||||
return {
|
||||
"do_resize": self.do_resize,
|
||||
"size": self.size,
|
||||
"do_normalize": self.do_normalize,
|
||||
"image_mean": self.image_mean,
|
||||
"image_std": self.image_std,
|
||||
"do_pad": self.do_pad,
|
||||
"num_labels": self.num_labels,
|
||||
}
|
||||
|
||||
def prepare_fake_eomt_outputs(self, batch_size):
|
||||
return EomtForUniversalSegmentationOutput(
|
||||
masks_queries_logits=torch.randn((batch_size, self.num_queries, self.height, self.width)),
|
||||
class_queries_logits=torch.randn((batch_size, self.num_queries, self.num_classes + 1)),
|
||||
)
|
||||
|
||||
def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
|
||||
return prepare_image_inputs(
|
||||
batch_size=self.batch_size,
|
||||
num_channels=self.num_channels,
|
||||
min_resolution=self.min_resolution,
|
||||
max_resolution=self.max_resolution,
|
||||
equal_resolution=equal_resolution,
|
||||
numpify=numpify,
|
||||
torchify=torchify,
|
||||
)
|
||||
|
||||
|
||||
def prepare_semantic_single_inputs():
|
||||
ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
|
||||
example = ds[0]
|
||||
return example["image"], example["map"]
|
||||
|
||||
|
||||
def prepare_semantic_batch_inputs():
|
||||
ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
|
||||
return list(ds["image"][:2]), list(ds["map"][:2])
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
class EomtImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
image_processing_class = EomtImageProcessor if is_vision_available() else None
|
||||
fast_image_processing_class = EomtImageProcessorFast if is_torchvision_available() else None
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.image_processor_tester = EomtImageProcessingTester(self)
|
||||
self.model_id = "tue-mps/coco_panoptic_eomt_large_640"
|
||||
|
||||
@property
|
||||
def image_processor_dict(self):
|
||||
return self.image_processor_tester.prepare_image_processor_dict()
|
||||
|
||||
def test_image_processor_properties(self):
|
||||
for image_processing_class in self.image_processor_list:
|
||||
image_processing = image_processing_class(**self.image_processor_dict)
|
||||
self.assertTrue(hasattr(image_processing, "image_mean"))
|
||||
self.assertTrue(hasattr(image_processing, "image_std"))
|
||||
self.assertTrue(hasattr(image_processing, "do_normalize"))
|
||||
self.assertTrue(hasattr(image_processing, "do_resize"))
|
||||
self.assertTrue(hasattr(image_processing, "size"))
|
||||
self.assertTrue(hasattr(image_processing, "do_rescale"))
|
||||
self.assertTrue(hasattr(image_processing, "rescale_factor"))
|
||||
self.assertTrue(hasattr(image_processing, "resample"))
|
||||
|
||||
def test_image_processor_from_dict_with_kwargs(self):
|
||||
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
|
||||
self.assertEqual(image_processor.size, {"shortest_edge": 18, "longest_edge": 18})
|
||||
|
||||
image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42)
|
||||
self.assertEqual(image_processor.size, {"shortest_edge": 42})
|
||||
|
||||
def test_call_numpy(self):
|
||||
for image_processing_class in self.image_processor_list:
|
||||
# Initialize image_processing
|
||||
image_processing = image_processing_class(**self.image_processor_dict)
|
||||
# create random numpy tensors
|
||||
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, numpify=True)
|
||||
for image in image_inputs:
|
||||
self.assertIsInstance(image, np.ndarray)
|
||||
|
||||
# Test not batched input
|
||||
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = (1, 3, 18, 18)
|
||||
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
|
||||
|
||||
# Test batched
|
||||
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = (2, 3, 18, 18)
|
||||
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
|
||||
|
||||
@unittest.skip(reason="Not supported")
|
||||
def test_call_numpy_4_channels(self):
|
||||
pass
|
||||
|
||||
def test_call_pil(self):
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True)
|
||||
for image in image_inputs:
|
||||
self.assertIsInstance(image, Image.Image)
|
||||
|
||||
# Test Non batched input
|
||||
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = (1, 3, 18, 18)
|
||||
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
|
||||
|
||||
# Test batched
|
||||
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = (2, 3, 18, 18)
|
||||
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
|
||||
|
||||
def test_call_pytorch(self):
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True)
|
||||
|
||||
for image in image_inputs:
|
||||
self.assertIsInstance(image, torch.Tensor)
|
||||
|
||||
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = (1, 3, 18, 18)
|
||||
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
|
||||
|
||||
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = (2, 3, 18, 18)
|
||||
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
|
||||
|
||||
def test_slow_fast_equivalence(self):
|
||||
if not self.test_slow_image_processor or not self.test_fast_image_processor:
|
||||
self.skipTest(reason="Skipping slow/fast equivalence test")
|
||||
|
||||
if self.image_processing_class is None or self.fast_image_processing_class is None:
|
||||
self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")
|
||||
|
||||
dummy_image, dummy_map = prepare_semantic_single_inputs()
|
||||
|
||||
image_processor_slow = self.image_processing_class(**self.image_processor_dict)
|
||||
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
|
||||
|
||||
image_encoding_slow = image_processor_slow(dummy_image, segmentation_maps=dummy_map, return_tensors="pt")
|
||||
image_encoding_fast = image_processor_fast(dummy_image, segmentation_maps=dummy_map, return_tensors="pt")
|
||||
|
||||
self.assertTrue(torch.allclose(image_encoding_slow.pixel_values, image_encoding_fast.pixel_values, atol=1e-1))
|
||||
self.assertLessEqual(
|
||||
torch.mean(torch.abs(image_encoding_slow.pixel_values - image_encoding_fast.pixel_values)).item(), 1e-3
|
||||
)
|
||||
|
||||
# Lets check whether 99.9% of mask_labels values match or not.
|
||||
match_ratio = (image_encoding_slow.mask_labels[0] == image_encoding_fast.mask_labels[0]).float().mean().item()
|
||||
self.assertGreaterEqual(match_ratio, 0.999, "Mask labels do not match between slow and fast image processor.")
|
||||
|
||||
def test_slow_fast_equivalence_batched(self):
|
||||
if not self.test_slow_image_processor or not self.test_fast_image_processor:
|
||||
self.skipTest(reason="Skipping slow/fast equivalence test")
|
||||
|
||||
if self.image_processing_class is None or self.fast_image_processing_class is None:
|
||||
self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")
|
||||
|
||||
if hasattr(self.image_processor_tester, "do_center_crop") and self.image_processor_tester.do_center_crop:
|
||||
self.skipTest(
|
||||
reason="Skipping as do_center_crop is True and center_crop functions are not equivalent for fast and slow processors"
|
||||
)
|
||||
|
||||
dummy_images, dummy_maps = prepare_semantic_batch_inputs()
|
||||
|
||||
image_processor_slow = self.image_processing_class(**self.image_processor_dict)
|
||||
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
|
||||
|
||||
encoding_slow = image_processor_slow(dummy_images, segmentation_maps=dummy_maps, return_tensors="pt")
|
||||
encoding_fast = image_processor_fast(dummy_images, segmentation_maps=dummy_maps, return_tensors="pt")
|
||||
|
||||
self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1))
|
||||
self.assertLessEqual(
|
||||
torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3
|
||||
)
|
||||
|
||||
for idx in range(len(dummy_maps)):
|
||||
match_ratio = (encoding_slow.mask_labels[idx] == encoding_fast.mask_labels[idx]).float().mean().item()
|
||||
self.assertGreaterEqual(
|
||||
match_ratio, 0.999, "Mask labels do not match between slow and fast image processors."
|
||||
)
|
||||
|
||||
def test_post_process_semantic_segmentation(self):
|
||||
processor = self.image_processing_class(**self.image_processor_dict)
|
||||
# Set longest_edge to None to test for semantic segmentatiom.
|
||||
processor.size = {"shortest_edge": 18, "longest_edge": None}
|
||||
image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
|
||||
|
||||
inputs = processor(images=image, do_split_image=True, return_tensors="pt")
|
||||
patch_offsets = inputs.pop("patch_offsets")
|
||||
|
||||
original_sizes = [image.size[::-1]]
|
||||
|
||||
# For semantic segmentation, the BS of output is 2 coz, two patches are created for the image.
|
||||
outputs = self.image_processor_tester.prepare_fake_eomt_outputs(inputs["pixel_values"].shape[0])
|
||||
segmentation = processor.post_process_semantic_segmentation(outputs, patch_offsets, original_sizes)
|
||||
|
||||
self.assertEqual(segmentation[0].shape, (image.height, image.width))
|
||||
|
||||
def test_post_process_panoptic_segmentation(self):
|
||||
processor = self.image_processing_class(**self.image_processor_dict)
|
||||
image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
|
||||
|
||||
original_sizes = [image.size[::-1], image.size[::-1]]
|
||||
|
||||
# lets test for batched input of 2
|
||||
outputs = self.image_processor_tester.prepare_fake_eomt_outputs(2)
|
||||
segmentation = processor.post_process_panoptic_segmentation(outputs, original_sizes)
|
||||
|
||||
self.assertTrue(len(segmentation) == 2)
|
||||
for el in segmentation:
|
||||
self.assertTrue("segmentation" in el)
|
||||
self.assertTrue("segments_info" in el)
|
||||
self.assertEqual(type(el["segments_info"]), list)
|
||||
self.assertEqual(el["segmentation"].shape, (image.height, image.width))
|
||||
|
||||
def test_post_process_instance_segmentation(self):
|
||||
processor = self.image_processing_class(**self.image_processor_dict)
|
||||
image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
|
||||
|
||||
original_sizes = [image.size[::-1], image.size[::-1]]
|
||||
|
||||
# lets test for batched input of 2
|
||||
outputs = self.image_processor_tester.prepare_fake_eomt_outputs(2)
|
||||
segmentation = processor.post_process_instance_segmentation(outputs, original_sizes)
|
||||
|
||||
self.assertTrue(len(segmentation) == 2)
|
||||
for el in segmentation:
|
||||
self.assertTrue("segmentation" in el)
|
||||
self.assertTrue("segments_info" in el)
|
||||
self.assertEqual(type(el["segments_info"]), list)
|
||||
self.assertEqual(el["segmentation"].shape, (image.height, image.width))
|
||||
475
tests/models/eomt/test_modeling_eomt.py
Normal file
475
tests/models/eomt/test_modeling_eomt.py
Normal file
@@ -0,0 +1,475 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Testing suite for the PyTorch EoMT model."""
|
||||
|
||||
import unittest
|
||||
|
||||
import requests
|
||||
|
||||
from transformers import AutoImageProcessor, EomtConfig, EomtForUniversalSegmentation
|
||||
from transformers.testing_utils import require_torch, require_torch_accelerator, require_torch_fp16, slow, torch_device
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class EomtForUniversalSegmentationTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=2,
|
||||
is_training=True,
|
||||
image_size=40,
|
||||
patch_size=2,
|
||||
num_queries=5,
|
||||
num_register_tokens=19,
|
||||
num_labels=4,
|
||||
hidden_size=8,
|
||||
num_attention_heads=2,
|
||||
num_hidden_layers=4,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.is_training = is_training
|
||||
self.num_queries = num_queries
|
||||
self.image_size = image_size
|
||||
self.patch_size = patch_size
|
||||
self.num_labels = num_labels
|
||||
self.hidden_size = hidden_size
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_register_tokens = num_register_tokens
|
||||
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
self.seq_length = num_patches + 1
|
||||
|
||||
def get_config(self):
|
||||
config = {
|
||||
"image_size": self.image_size,
|
||||
"patch_size": self.patch_size,
|
||||
"num_labels": self.num_labels,
|
||||
"hidden_size": self.hidden_size,
|
||||
"num_attention_heads": self.num_attention_heads,
|
||||
"num_hidden_layers": self.num_hidden_layers,
|
||||
"num_register_tokens": self.num_register_tokens,
|
||||
"num_queries": self.num_queries,
|
||||
"num_blocks": 1,
|
||||
}
|
||||
return EomtConfig(**config)
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
pixel_values = floats_tensor([self.batch_size, 3, self.image_size, self.image_size]).to(torch_device)
|
||||
|
||||
mask_labels = (
|
||||
torch.rand([self.batch_size, self.num_labels, self.image_size, self.image_size], device=torch_device) > 0.5
|
||||
).float()
|
||||
class_labels = (torch.rand((self.batch_size, self.num_labels), device=torch_device) > 0.5).long()
|
||||
|
||||
config = self.get_config()
|
||||
return config, pixel_values, mask_labels, class_labels
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config, pixel_values, mask_labels, class_labels = self.prepare_config_and_inputs()
|
||||
inputs_dict = {"pixel_values": pixel_values}
|
||||
return config, inputs_dict
|
||||
|
||||
def prepare_config_and_inputs_for_training(self):
|
||||
config, pixel_values, mask_labels, class_labels = self.prepare_config_and_inputs()
|
||||
inputs_dict = {"pixel_values": pixel_values, "mask_labels": mask_labels, "class_labels": class_labels}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_torch
|
||||
class EomtForUniversalSegmentationTest(ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (EomtForUniversalSegmentation,) if is_torch_available() else ()
|
||||
is_encoder_decoder = False
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
test_missing_keys = False
|
||||
test_torch_exportable = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = EomtForUniversalSegmentationTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=EomtConfig, has_text_modality=False)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_model_with_labels(self):
|
||||
size = (self.model_tester.image_size,) * 2
|
||||
inputs = {
|
||||
"pixel_values": torch.randn((2, 3, *size), device=torch_device),
|
||||
"mask_labels": torch.randn((2, 10, *size), device=torch_device),
|
||||
"class_labels": torch.zeros(2, 10, device=torch_device).long(),
|
||||
}
|
||||
config = self.model_tester.get_config()
|
||||
|
||||
model = EomtForUniversalSegmentation(config).to(torch_device)
|
||||
outputs = model(**inputs)
|
||||
self.assertTrue(outputs.loss is not None)
|
||||
|
||||
def test_attention_outputs(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.return_dict = True
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
inputs_dict["output_attentions"] = True
|
||||
inputs_dict["output_hidden_states"] = False
|
||||
config.return_dict = True
|
||||
model = model_class._from_config(config, attn_implementation="eager")
|
||||
config = model.config
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
attentions = outputs.attentions
|
||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
||||
|
||||
# Check that output_attentions also work using config
|
||||
del inputs_dict["output_attentions"]
|
||||
config.output_attentions = True
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
attentions = outputs.attentions
|
||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
||||
|
||||
out_len = len(outputs)
|
||||
# Check attention is always last and order is fine
|
||||
inputs_dict["output_attentions"] = True
|
||||
inputs_dict["output_hidden_states"] = True
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
added_hidden_states = 1
|
||||
self.assertEqual(out_len + added_hidden_states, len(outputs))
|
||||
|
||||
self_attentions = outputs.attentions
|
||||
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
|
||||
|
||||
def test_hidden_states_output(self):
|
||||
def check_hidden_states_output(inputs_dict, config, model_class):
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
|
||||
|
||||
expected_num_layers = getattr(
|
||||
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
|
||||
)
|
||||
self.assertEqual(len(hidden_states), expected_num_layers)
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
inputs_dict["output_hidden_states"] = True
|
||||
check_hidden_states_output(inputs_dict, config, model_class)
|
||||
|
||||
# check that output_hidden_states also work using config
|
||||
del inputs_dict["output_hidden_states"]
|
||||
config.output_hidden_states = True
|
||||
|
||||
check_hidden_states_output(inputs_dict, config, model_class)
|
||||
|
||||
@unittest.skip(reason="EoMT does not use inputs_embeds")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="EoMT does not have a get_input_embeddings method")
|
||||
def test_model_get_set_embeddings(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="EoMT is not a generative model")
|
||||
def test_generate_without_input_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="EoMT does not use token embeddings")
|
||||
def test_resize_tokens_embeddings(self):
|
||||
pass
|
||||
|
||||
def test_training(self):
|
||||
if not self.model_tester.is_training:
|
||||
self.skipTest(reason="ModelTester is not configured to run training tests")
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_training()
|
||||
config.return_dict = True
|
||||
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.train()
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||
loss = model(**inputs).loss
|
||||
loss.backward()
|
||||
|
||||
def test_initialization(self):
|
||||
# Apart from the below params, all other parameters are initialized using kaiming uniform.
|
||||
non_uniform_init_parms = [
|
||||
"layernorm.bias",
|
||||
"layernorm.weight",
|
||||
"norm1.bias",
|
||||
"norm1.weight",
|
||||
"norm2.bias",
|
||||
"norm2.weight",
|
||||
"layer_scale1.lambda1",
|
||||
"layer_scale2.lambda1",
|
||||
"register_tokens",
|
||||
"cls_token",
|
||||
]
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
configs_no_init = _config_zero_init(config)
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config=configs_no_init)
|
||||
for name, param in model.named_parameters():
|
||||
if param.requires_grad:
|
||||
if any(x in name for x in non_uniform_init_parms):
|
||||
self.assertIn(
|
||||
((param.data.mean() * 1e9).round() / 1e9).item(),
|
||||
[0.0, 1.0],
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
)
|
||||
else:
|
||||
self.assertTrue(
|
||||
-1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
)
|
||||
|
||||
|
||||
@require_torch
|
||||
class EomtForUniversalSegmentationIntegrationTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.model_id = "tue-mps/coco_panoptic_eomt_large_640"
|
||||
|
||||
@slow
|
||||
def test_inference(self):
|
||||
model = EomtForUniversalSegmentation.from_pretrained(self.model_id, device_map="auto")
|
||||
processor = AutoImageProcessor.from_pretrained(self.model_id)
|
||||
|
||||
image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
|
||||
|
||||
inputs = processor(images=image, return_tensors="pt").to(model.device)
|
||||
|
||||
with torch.inference_mode():
|
||||
outputs = model(**inputs)
|
||||
|
||||
self.assertTrue(outputs.class_queries_logits.shape == (1, 200, 134))
|
||||
self.assertTrue(outputs.masks_queries_logits.shape == (1, 200, 160, 160))
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_SLICE = torch.tensor([
|
||||
[ 13.2540, 8.9279, 8.6631, 12.3760, 10.1429],
|
||||
[ -3.4815, -36.4630, -45.5604, -46.8404, -37.5099],
|
||||
[ -6.8689, -44.4206, -62.7591, -59.2928, -47.7035],
|
||||
[ -2.9380, -42.0659, -57.4382, -55.1537, -43.5142],
|
||||
[ -8.4387, -38.5275, -53.1383, -47.0064, -38.9667],
|
||||
]).to(model.device)
|
||||
# fmt: on
|
||||
|
||||
output_slice = outputs.masks_queries_logits[0, 0, :5, :5]
|
||||
torch.testing.assert_close(output_slice, EXPECTED_SLICE, rtol=1e-2, atol=1e-2)
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_SLICE = torch.tensor([
|
||||
[-0.6977, -6.4907, -4.1178, -6.5554, -6.6529],
|
||||
[-0.3650, -6.6560, -4.0143, -6.5776, -6.5879],
|
||||
[-0.8820, -6.7175, -3.5334, -6.8569, -6.2415],
|
||||
[ 0.4502, -5.3911, -3.0232, -5.9411, -6.3243],
|
||||
[ 0.3157, -5.6321, -2.6716, -5.5740, -5.5607],
|
||||
]).to(model.device)
|
||||
# fmt: on
|
||||
|
||||
output_slice = outputs.class_queries_logits[0, :5, :5]
|
||||
torch.testing.assert_close(output_slice, EXPECTED_SLICE, rtol=1e-2, atol=1e-2)
|
||||
|
||||
@require_torch_accelerator
|
||||
@require_torch_fp16
|
||||
@slow
|
||||
def test_inference_fp16(self):
|
||||
model = EomtForUniversalSegmentation.from_pretrained(
|
||||
self.model_id, torch_dtype=torch.float16, device_map="auto"
|
||||
)
|
||||
processor = AutoImageProcessor.from_pretrained(self.model_id)
|
||||
|
||||
image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
|
||||
|
||||
inputs = processor(images=image, return_tensors="pt").to(model.device)
|
||||
|
||||
with torch.inference_mode():
|
||||
outputs = model(**inputs)
|
||||
|
||||
self.assertTrue(outputs.class_queries_logits.shape == (1, 200, 134))
|
||||
self.assertTrue(outputs.masks_queries_logits.shape == (1, 200, 160, 160))
|
||||
|
||||
@slow
|
||||
def test_semantic_segmentation_inference(self):
|
||||
model_id = "tue-mps/ade20k_semantic_eomt_large_512"
|
||||
model = EomtForUniversalSegmentation.from_pretrained(model_id, device_map="auto")
|
||||
processor = AutoImageProcessor.from_pretrained(model_id)
|
||||
|
||||
image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
|
||||
|
||||
inputs = processor(images=image, return_tensors="pt").to(model.device)
|
||||
patch_offsets = inputs.pop("patch_offsets", None)
|
||||
|
||||
with torch.inference_mode():
|
||||
outputs = model(**inputs)
|
||||
|
||||
self.assertTrue(outputs.class_queries_logits.shape == (2, 100, 151))
|
||||
self.assertTrue(outputs.masks_queries_logits.shape == (2, 100, 128, 128))
|
||||
|
||||
preds = processor.post_process_semantic_segmentation(
|
||||
outputs, original_image_sizes=[(image.size[1], image.size[0])], patch_offsets=patch_offsets
|
||||
)
|
||||
|
||||
self.assertTrue(preds.shape[1:] == (image.size[1], image.size[0]))
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_SLICE = torch.tensor([
|
||||
[39, 39, 39, 39, 39, 39, 39, 39, 39, 39],
|
||||
[39, 39, 39, 39, 39, 39, 39, 39, 39, 39],
|
||||
[39, 39, 39, 39, 39, 39, 39, 39, 39, 39],
|
||||
[39, 39, 39, 39, 39, 39, 39, 39, 39, 39],
|
||||
[39, 39, 39, 39, 39, 39, 39, 39, 39, 39],
|
||||
[39, 39, 39, 39, 39, 39, 39, 39, 39, 39],
|
||||
[39, 39, 39, 39, 39, 39, 39, 39, 39, 39],
|
||||
[39, 39, 39, 39, 39, 39, 39, 39, 39, 39],
|
||||
[39, 39, 39, 39, 39, 39, 39, 39, 39, 39],
|
||||
[39, 39, 39, 39, 39, 39, 39, 39, 39, 39]
|
||||
], device=model.device)
|
||||
# fmt: on
|
||||
|
||||
output_slice = preds[0, :10, :10]
|
||||
torch.testing.assert_close(output_slice, EXPECTED_SLICE, rtol=1e-2, atol=1e-2)
|
||||
|
||||
@slow
|
||||
def test_panoptic_segmentation_inference(self):
|
||||
model = EomtForUniversalSegmentation.from_pretrained(self.model_id, device_map="auto")
|
||||
processor = AutoImageProcessor.from_pretrained(self.model_id)
|
||||
|
||||
image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
|
||||
|
||||
inputs = processor(images=image, return_tensors="pt").to(model.device)
|
||||
|
||||
with torch.inference_mode():
|
||||
outputs = model(**inputs)
|
||||
|
||||
self.assertTrue(outputs.class_queries_logits.shape == (1, 200, 134))
|
||||
self.assertTrue(outputs.masks_queries_logits.shape == (1, 200, 160, 160))
|
||||
|
||||
preds = processor.post_process_panoptic_segmentation(
|
||||
outputs, original_image_sizes=[(image.size[1], image.size[0])]
|
||||
)[0]
|
||||
segmentation, segments_info = preds["segmentation"], preds["segments_info"]
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_SLICE = torch.tensor([
|
||||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[-1, -1, -1, -1, -1, 2, 2, 2, 2, 2],
|
||||
[-1, -1, -1, 2, 2, 2, 2, 2, 2, 2],
|
||||
[ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
|
||||
[ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
|
||||
[ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
|
||||
[ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
|
||||
[ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
|
||||
], device=model.device)
|
||||
|
||||
EXPECTED_SEGMENTS_INFO = [
|
||||
{"id": 0, "label_id": 15, "score": 0.99935},
|
||||
{"id": 1, "label_id": 15, "score": 0.998688},
|
||||
{"id": 2, "label_id": 57, "score": 0.954325},
|
||||
{"id": 3, "label_id": 65, "score": 0.997285},
|
||||
{"id": 4, "label_id": 65, "score": 0.99711}
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
output_slice = segmentation[:10, :10]
|
||||
torch.testing.assert_close(output_slice, EXPECTED_SLICE, rtol=1e-2, atol=1e-2)
|
||||
for actual, expected in zip(segments_info, EXPECTED_SEGMENTS_INFO):
|
||||
self.assertEqual(actual["id"], expected["id"])
|
||||
self.assertEqual(actual["label_id"], expected["label_id"])
|
||||
self.assertAlmostEqual(actual["score"], expected["score"], delta=1e-3)
|
||||
|
||||
@slow
|
||||
def test_instance_segmentation_inference(self):
|
||||
model_id = "tue-mps/coco_instance_eomt_large_640"
|
||||
model = EomtForUniversalSegmentation.from_pretrained(model_id, device_map="auto")
|
||||
processor = AutoImageProcessor.from_pretrained(model_id)
|
||||
|
||||
image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
|
||||
|
||||
inputs = processor(images=image, return_tensors="pt").to(model.device)
|
||||
|
||||
with torch.inference_mode():
|
||||
outputs = model(**inputs)
|
||||
|
||||
self.assertTrue(outputs.class_queries_logits.shape == (1, 200, 81))
|
||||
self.assertTrue(outputs.masks_queries_logits.shape == (1, 200, 160, 160))
|
||||
|
||||
preds = processor.post_process_instance_segmentation(
|
||||
outputs, original_image_sizes=[(image.size[1], image.size[0])]
|
||||
)[0]
|
||||
segmentation, segments_info = preds["segmentation"], preds["segments_info"]
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_SLICE = torch.tensor([
|
||||
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
|
||||
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
|
||||
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
|
||||
[-1., -1., -1., 0., 0., 1., 1., 1., 1., 1.],
|
||||
[ 0., 0., 1., 1., 1., 1., 1., 1., 1., 1.],
|
||||
[ 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
|
||||
[ 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
|
||||
[ 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
|
||||
[ 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
|
||||
[ 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]
|
||||
], device=model.device)
|
||||
|
||||
EXPECTED_SEGMENTS_INFO = [
|
||||
{'id': 0, 'label_id': 57, 'score': 0.871247},
|
||||
{'id': 1, 'label_id': 57, 'score': 0.821225},
|
||||
{'id': 2, 'label_id': 15, 'score': 0.976252},
|
||||
{'id': 3, 'label_id': 65, 'score': 0.972960},
|
||||
{'id': 4, 'label_id': 65, 'score': 0.981109},
|
||||
{'id': 5, 'label_id': 15, 'score': 0.972689}
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
output_slice = segmentation[:10, :10]
|
||||
torch.testing.assert_close(output_slice, EXPECTED_SLICE, rtol=1e-2, atol=1e-2)
|
||||
for actual, expected in zip(segments_info, EXPECTED_SEGMENTS_INFO):
|
||||
self.assertEqual(actual["id"], expected["id"])
|
||||
self.assertEqual(actual["label_id"], expected["label_id"])
|
||||
self.assertAlmostEqual(actual["score"], expected["score"], delta=1e-3)
|
||||
Reference in New Issue
Block a user