🚨🚨🚨 [eomt] make EoMT compatible with pipeline (#39122)

* Make EoMT compatible with pipeline

* Implicit patch offsets

* remove patch offsets from arg

* Modify tests

* Update example

* fix proc testcase

* Add few more args

* add pipeline test suite

* fix

* docstring fixes

* add pipeline test

* changes w.r.t review

* 🙈 MB

* should fix device mismatch

* debug

* Fixes device mismatch

* use decorator

* we can split mlp

* expected values update

---------

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
This commit is contained in:
Yaswanth Gali
2025-07-02 16:55:26 +05:30
committed by GitHub
parent 4d5822e65d
commit b61023a1b7
7 changed files with 113 additions and 92 deletions

View File

@@ -74,20 +74,16 @@ inputs = processor(
return_tensors="pt", return_tensors="pt",
) )
# Remove Patch Offsets from inputs — only used later for post-processing.
patch_offsets = inputs.pop("patch_offsets")
with torch.inference_mode(): with torch.inference_mode():
outputs = model(**inputs) outputs = model(**inputs)
# Prepare the original image size in the format (height, width) # Prepare the original image size in the format (height, width)
original_image_sizes = [(image.height, image.width)] target_sizes = [(image.height, image.width)]
# Post-process the model outputs to get final segmentation prediction # Post-process the model outputs to get final segmentation prediction
preds = processor.post_process_semantic_segmentation( preds = processor.post_process_semantic_segmentation(
outputs, outputs,
patch_offsets=patch_offsets, target_sizes=target_sizes,
original_image_sizes=original_image_sizes,
) )
# Visualize the segmentation mask # Visualize the segmentation mask
@@ -130,12 +126,12 @@ with torch.inference_mode():
outputs = model(**inputs) outputs = model(**inputs)
# Prepare the original image size in the format (height, width) # Prepare the original image size in the format (height, width)
original_image_sizes = [(image.height, image.width)] target_sizes = [(image.height, image.width)]
# Post-process the model outputs to get final segmentation prediction # Post-process the model outputs to get final segmentation prediction
preds = processor.post_process_instance_segmentation( preds = processor.post_process_instance_segmentation(
outputs, outputs,
original_image_sizes=original_image_sizes, target_sizes=target_sizes,
) )
# Visualize the segmentation mask # Visualize the segmentation mask
@@ -173,12 +169,12 @@ with torch.inference_mode():
outputs = model(**inputs) outputs = model(**inputs)
# Prepare the original image size in the format (height, width) # Prepare the original image size in the format (height, width)
original_image_sizes = [(image.height, image.width)] target_sizes = [(image.height, image.width)]
# Post-process the model outputs to get final segmentation prediction # Post-process the model outputs to get final segmentation prediction
preds = processor.post_process_panoptic_segmentation( preds = processor.post_process_panoptic_segmentation(
outputs, outputs,
original_image_sizes=original_image_sizes, target_sizes=target_sizes,
) )
# Visualize the panoptic segmentation mask # Visualize the panoptic segmentation mask

View File

@@ -97,7 +97,7 @@ def get_size_with_aspect_ratio(image_size, size, max_size=None) -> tuple[int, in
Computes the output image size given the input image size and the desired output size. Computes the output image size given the input image size and the desired output size.
Args: Args:
image_size (`Tuple[int, int]`): image_size (`tuple[int, int]`):
The input image size. The input image size.
size (`int`): size (`int`):
The desired output size. The desired output size.
@@ -531,13 +531,13 @@ class EomtImageProcessor(BaseImageProcessor):
Image or batch of images to preprocess. Image or batch of images to preprocess.
segmentation_maps (`ImageInput`, *optional*): segmentation_maps (`ImageInput`, *optional*):
The corresponding semantic segmentation maps with the pixel-wise annotations. The corresponding semantic segmentation maps with the pixel-wise annotations.
instance_id_to_semantic_id (`List[Dict[int, int]]` or `Dict[int, int]`, *optional*): instance_id_to_semantic_id (`list[dict[int, int]]` or `dict[int, int]`, *optional*):
A mapping between object instance ids and class ids. A mapping between object instance ids and class ids.
do_split_image (`bool`, *optional*, defaults to `self.do_split_image`): do_split_image (`bool`, *optional*, defaults to `self.do_split_image`):
Whether to split the input images into overlapping patches for semantic segmentation. Whether to split the input images into overlapping patches for semantic segmentation.
do_resize (`bool`, *optional*, defaults to `self.do_resize`): do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the input images. Whether to resize the input images.
size (`Dict[str, int]`, *optional*, defaults to `self.size`): size (`dict[str, int]`, *optional*, defaults to `self.size`):
Target size as a dictionary with `"shortest_edge"` and `"longest_edge"` keys. Target size as a dictionary with `"shortest_edge"` and `"longest_edge"` keys.
resample (`PILImageResampling`, *optional*, defaults to `self.resample`): resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
Resampling filter to use when resizing. Resampling filter to use when resizing.
@@ -550,9 +550,9 @@ class EomtImageProcessor(BaseImageProcessor):
do_pad (`bool`, *optional*, defaults to `False`): do_pad (`bool`, *optional*, defaults to `False`):
Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
number of patches in the batch. Padding will be applied to the bottom and right with zeros. number of patches in the batch. Padding will be applied to the bottom and right with zeros.
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
Mean for normalization. Single value or list for each channel. Mean for normalization. Single value or list for each channel.
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
Standard deviation for normalization. Single value or list for each channel. Standard deviation for normalization. Single value or list for each channel.
ignore_index (`int`, *optional*): ignore_index (`int`, *optional*):
Label to be assigned to background pixels in segmentation maps. If provided, segmentation map pixels Label to be assigned to background pixels in segmentation maps. If provided, segmentation map pixels
@@ -640,7 +640,7 @@ class EomtImageProcessor(BaseImageProcessor):
) )
if do_split_image and patch_offsets: if do_split_image and patch_offsets:
encoded_inputs["patch_offsets"] = patch_offsets encoded_inputs["patch_offsets"] = [torch.tensor(offsets) for offsets in patch_offsets]
return encoded_inputs return encoded_inputs
@@ -663,8 +663,8 @@ class EomtImageProcessor(BaseImageProcessor):
each mask. each mask.
Args: Args:
pixel_values_list (`List[ImageInput]`): pixel_values_list (`list[ImageInput]`):
List of images (pixel values) to be padded. Each image should be a tensor of shape `(channels, height, list of images (pixel values) to be padded. Each image should be a tensor of shape `(channels, height,
width)`. width)`.
segmentation_maps (`ImageInput`, *optional*): segmentation_maps (`ImageInput`, *optional*):
@@ -678,7 +678,7 @@ class EomtImageProcessor(BaseImageProcessor):
- 1 for pixels that are real (i.e. **not masked**), - 1 for pixels that are real (i.e. **not masked**),
- 0 for pixels that are padding (i.e. **masked**). - 0 for pixels that are padding (i.e. **masked**).
instance_id_to_semantic_id (`List[Dict[int, int]]` or `Dict[int, int]`, *optional*): instance_id_to_semantic_id (`list[dict[int, int]]` or `dict[int, int]`, *optional*):
A mapping between object instance ids and class ids. If passed, `segmentation_maps` is treated as an A mapping between object instance ids and class ids. If passed, `segmentation_maps` is treated as an
instance segmentation map where each pixel represents an instance id. Can be provided as a single instance segmentation map where each pixel represents an instance id. Can be provided as a single
dictionary with a global/dataset-level mapping or as a list of dictionaries (one per image), to map dictionary with a global/dataset-level mapping or as a list of dictionaries (one per image), to map
@@ -740,7 +740,7 @@ class EomtImageProcessor(BaseImageProcessor):
self, self,
segmentation_logits: torch.Tensor, segmentation_logits: torch.Tensor,
patch_offsets: list[tuple[int, int, int]], patch_offsets: list[tuple[int, int, int]],
original_image_sizes: list[tuple[int, int]], target_sizes: list[tuple[int, int]],
size: dict[str, int], size: dict[str, int],
) -> list[torch.Tensor]: ) -> list[torch.Tensor]:
""" """
@@ -750,28 +750,28 @@ class EomtImageProcessor(BaseImageProcessor):
segmentation_logits (`torch.Tensor`): segmentation_logits (`torch.Tensor`):
A tensor of shape `(num_patches, num_classes, patch_height, patch_width)` representing predicted logits A tensor of shape `(num_patches, num_classes, patch_height, patch_width)` representing predicted logits
for each image patch. for each image patch.
patch_offsets (`List[Tuple[int, int, int]]`): patch_offsets (`list[tuple[int, int, int]]`):
A list of tuples where each tuple contains: A list of tuples where each tuple contains:
- `image_index` (int): Index of the original image this patch belongs to. - `image_index` (int): Index of the original image this patch belongs to.
- `start` (int): Start pixel index of the patch along the long dimension (height or width). - `start` (int): Start pixel index of the patch along the long dimension (height or width).
- `end` (int): End pixel index of the patch along the long dimension. - `end` (int): End pixel index of the patch along the long dimension.
original_image_sizes (`List[Tuple[int, int]]`): target_sizes (`list[tuple[int, int]]`):
List of original (height, width) dimensions for each image before preprocessing. list of original (height, width) dimensions for each image before preprocessing.
size (`Dict[str, int]`): size (`dict[str, int]`):
A size dict which was used to resize. A size dict which was used to resize.
""" """
num_classes = segmentation_logits.shape[1] num_classes = segmentation_logits.shape[1]
aggregated_logits = [] aggregated_logits = []
patch_counts = [] patch_counts = []
for image_size in original_image_sizes: for image_size in target_sizes:
height, width = get_size_with_aspect_ratio(image_size, size["shortest_edge"], size["longest_edge"]) height, width = get_size_with_aspect_ratio(image_size, size["shortest_edge"], size["longest_edge"])
aggregated_logits.append(torch.zeros((num_classes, height, width), device=segmentation_logits.device)) aggregated_logits.append(torch.zeros((num_classes, height, width), device=segmentation_logits.device))
patch_counts.append(torch.zeros((num_classes, height, width), device=segmentation_logits.device)) patch_counts.append(torch.zeros((num_classes, height, width), device=segmentation_logits.device))
# Stitch patches back into full-sized logit maps # Stitch patches back into full-sized logit maps
for patch_idx, (image_idx, patch_start, patch_end) in enumerate(patch_offsets): for patch_idx, (image_idx, patch_start, patch_end) in enumerate(patch_offsets):
if original_image_sizes[image_idx][0] > original_image_sizes[image_idx][1]: if target_sizes[image_idx][0] > target_sizes[image_idx][1]:
aggregated_logits[image_idx][:, patch_start:patch_end, :] += segmentation_logits[patch_idx] aggregated_logits[image_idx][:, patch_start:patch_end, :] += segmentation_logits[patch_idx]
patch_counts[image_idx][:, patch_start:patch_end, :] += 1 patch_counts[image_idx][:, patch_start:patch_end, :] += 1
else: else:
@@ -784,7 +784,7 @@ class EomtImageProcessor(BaseImageProcessor):
averaged_logits = logit_sum / count.clamp(min=1) averaged_logits = logit_sum / count.clamp(min=1)
resized_logits = F.interpolate( resized_logits = F.interpolate(
averaged_logits[None, ...], averaged_logits[None, ...],
size=original_image_sizes[idx], size=target_sizes[idx],
mode="bilinear", mode="bilinear",
align_corners=False, align_corners=False,
)[0] )[0]
@@ -796,14 +796,14 @@ class EomtImageProcessor(BaseImageProcessor):
def unpad_image( def unpad_image(
self, self,
segmentation_logits: torch.Tensor, segmentation_logits: torch.Tensor,
original_image_sizes: list[tuple[int, int]], target_sizes: list[tuple[int, int]],
size: dict[str, int], size: dict[str, int],
) -> list[torch.Tensor]: ) -> list[torch.Tensor]:
"""Restores panoptic segmentation logits to their original image resolutions.""" """Restores panoptic segmentation logits to their original image resolutions."""
resized_logits = [] resized_logits = []
for idx, original_size in enumerate(original_image_sizes): for idx, original_size in enumerate(target_sizes):
target_height, target_width = get_size_with_aspect_ratio( target_height, target_width = get_size_with_aspect_ratio(
original_size, size["shortest_edge"], size["longest_edge"] original_size, size["shortest_edge"], size["longest_edge"]
) )
@@ -817,8 +817,7 @@ class EomtImageProcessor(BaseImageProcessor):
def post_process_semantic_segmentation( def post_process_semantic_segmentation(
self, self,
outputs, outputs,
patch_offsets: list[tuple[int, int, int]], target_sizes: list[tuple[int, int]],
original_image_sizes: list[tuple[int, int]],
size: Optional[dict[str, int]] = None, size: Optional[dict[str, int]] = None,
) -> np.ndarray: ) -> np.ndarray:
"""Post-processes model outputs into final semantic segmentation prediction.""" """Post-processes model outputs into final semantic segmentation prediction."""
@@ -827,6 +826,7 @@ class EomtImageProcessor(BaseImageProcessor):
masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width] masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width]
class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1] class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1]
patch_offsets = outputs.patch_offsets
output_size = get_target_size(size) output_size = get_target_size(size)
masks_queries_logits = F.interpolate( masks_queries_logits = F.interpolate(
@@ -841,15 +841,15 @@ class EomtImageProcessor(BaseImageProcessor):
segmentation_logits = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs) segmentation_logits = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs)
output_logits = self.merge_image_patches(segmentation_logits, patch_offsets, original_image_sizes, size) output_logits = self.merge_image_patches(segmentation_logits, patch_offsets, target_sizes, size)
preds = torch.stack(output_logits).argmax(dim=1) preds = [logit.argmax(dim=0) for logit in output_logits]
return preds return preds
def post_process_panoptic_segmentation( def post_process_panoptic_segmentation(
self, self,
outputs, outputs,
original_image_sizes: list[tuple[int, int]], target_sizes: list[tuple[int, int]],
threshold: float = 0.8, threshold: float = 0.8,
mask_threshold: float = 0.5, mask_threshold: float = 0.5,
overlap_mask_area_threshold: float = 0.8, overlap_mask_area_threshold: float = 0.8,
@@ -873,7 +873,7 @@ class EomtImageProcessor(BaseImageProcessor):
mode="bilinear", mode="bilinear",
) )
mask_probs_batch = self.unpad_image(masks_queries_logits, original_image_sizes, size) mask_probs_batch = self.unpad_image(masks_queries_logits, target_sizes, size)
pred_scores_batch, pred_labels_batch = class_queries_logits.softmax(dim=-1).max(-1) pred_scores_batch, pred_labels_batch = class_queries_logits.softmax(dim=-1).max(-1)
results: list = [] results: list = []
@@ -885,7 +885,7 @@ class EomtImageProcessor(BaseImageProcessor):
# No mask found # No mask found
if mask_probs.shape[0] <= 0: if mask_probs.shape[0] <= 0:
height, width = original_image_sizes[i] if original_image_sizes is not None else mask_probs.shape[1:] height, width = target_sizes[i] if target_sizes is not None else mask_probs.shape[1:]
segmentation = torch.zeros((height, width)) - 1 segmentation = torch.zeros((height, width)) - 1
results.append({"segmentation": segmentation, "segments_info": []}) results.append({"segmentation": segmentation, "segments_info": []})
continue continue
@@ -897,16 +897,17 @@ class EomtImageProcessor(BaseImageProcessor):
stuff_classes=stuff_classes, stuff_classes=stuff_classes,
mask_threshold=mask_threshold, mask_threshold=mask_threshold,
overlap_mask_area_threshold=overlap_mask_area_threshold, overlap_mask_area_threshold=overlap_mask_area_threshold,
target_size=original_image_sizes[i] if original_image_sizes is not None else None, target_size=target_sizes[i] if target_sizes is not None else None,
) )
results.append({"segmentation": segmentation, "segments_info": segments}) results.append({"segmentation": segmentation, "segments_info": segments})
return results return results
@filter_out_non_signature_kwargs()
def post_process_instance_segmentation( def post_process_instance_segmentation(
self, self,
outputs, outputs,
original_image_sizes: list[tuple[int, int]], target_sizes: list[tuple[int, int]],
threshold: float = 0.5, threshold: float = 0.5,
size: Optional[dict[str, int]] = None, size: Optional[dict[str, int]] = None,
): ):
@@ -924,7 +925,7 @@ class EomtImageProcessor(BaseImageProcessor):
mode="bilinear", mode="bilinear",
) )
mask_probs_batch = self.unpad_image(masks_queries_logits, original_image_sizes, size) mask_probs_batch = self.unpad_image(masks_queries_logits, target_sizes, size)
device = masks_queries_logits.device device = masks_queries_logits.device
batch_size = class_queries_logits.shape[0] batch_size = class_queries_logits.shape[0]
@@ -946,7 +947,7 @@ class EomtImageProcessor(BaseImageProcessor):
) )
pred_scores = scores * mask_scores pred_scores = scores * mask_scores
segmentation = torch.zeros(original_image_sizes[i], device=device) - 1 segmentation = torch.zeros(target_sizes[i], device=device) - 1
instance_maps, segments = [], [] instance_maps, segments = [], []
current_segment_id = 0 current_segment_id = 0

View File

@@ -41,6 +41,7 @@ from ...processing_utils import Unpack
from ...utils import ( from ...utils import (
TensorType, TensorType,
auto_docstring, auto_docstring,
filter_out_non_signature_kwargs,
is_torch_available, is_torch_available,
is_torchvision_available, is_torchvision_available,
is_torchvision_v2_available, is_torchvision_v2_available,
@@ -268,7 +269,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
r""" r"""
segmentation_maps (`ImageInput`, *optional*): segmentation_maps (`ImageInput`, *optional*):
The segmentation maps to preprocess for corresponding images. The segmentation maps to preprocess for corresponding images.
instance_id_to_semantic_id (`List[Dict[int, int]]` or `Dict[int, int]`, *optional*): instance_id_to_semantic_id (`list[dict[int, int]]` or `dict[int, int]`, *optional*):
A mapping between object instance ids and class ids. A mapping between object instance ids and class ids.
""" """
# args are not validated, but their order in the `preprocess` and `_preprocess` signatures must be the same # args are not validated, but their order in the `preprocess` and `_preprocess` signatures must be the same
@@ -340,7 +341,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
outputs["class_labels"] = class_labels outputs["class_labels"] = class_labels
if patch_offsets: if patch_offsets:
outputs["patch_offsets"] = patch_offsets outputs["patch_offsets"] = [torch.tensor(offsets) for offsets in patch_offsets]
return outputs return outputs
@@ -348,7 +349,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
self, self,
segmentation_logits: torch.Tensor, segmentation_logits: torch.Tensor,
patch_offsets: list[tuple[int, int, int]], patch_offsets: list[tuple[int, int, int]],
original_image_sizes: list[tuple[int, int]], target_sizes: list[tuple[int, int]],
size: dict[str, int], size: dict[str, int],
) -> list[torch.Tensor]: ) -> list[torch.Tensor]:
""" """
@@ -358,28 +359,28 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
segmentation_logits (`torch.Tensor`): segmentation_logits (`torch.Tensor`):
A tensor of shape `(num_patches, num_classes, patch_height, patch_width)` representing predicted logits A tensor of shape `(num_patches, num_classes, patch_height, patch_width)` representing predicted logits
for each image patch. for each image patch.
patch_offsets (`List[Tuple[int, int, int]]`): patch_offsets (`list[tuple[int, int, int]]`):
A list of tuples where each tuple contains: A list of tuples where each tuple contains:
- `image_index` (int): Index of the original image this patch belongs to. - `image_index` (int): Index of the original image this patch belongs to.
- `start` (int): Start pixel index of the patch along the long dimension (height or width). - `start` (int): Start pixel index of the patch along the long dimension (height or width).
- `end` (int): End pixel index of the patch along the long dimension. - `end` (int): End pixel index of the patch along the long dimension.
original_image_sizes (`List[Tuple[int, int]]`): target_sizes (`list[tuple[int, int]]`):
List of original (height, width) dimensions for each image before preprocessing. list of original (height, width) dimensions for each image before preprocessing.
size (`Dict[str, int]`): size (`dict[str, int]`):
A size dict which was used to resize. A size dict which was used to resize.
""" """
num_classes = segmentation_logits.shape[1] num_classes = segmentation_logits.shape[1]
aggregated_logits = [] aggregated_logits = []
patch_counts = [] patch_counts = []
for image_size in original_image_sizes: for image_size in target_sizes:
height, width = get_size_with_aspect_ratio(image_size, size["shortest_edge"], size["longest_edge"]) height, width = get_size_with_aspect_ratio(image_size, size["shortest_edge"], size["longest_edge"])
aggregated_logits.append(torch.zeros((num_classes, height, width), device=segmentation_logits.device)) aggregated_logits.append(torch.zeros((num_classes, height, width), device=segmentation_logits.device))
patch_counts.append(torch.zeros((num_classes, height, width), device=segmentation_logits.device)) patch_counts.append(torch.zeros((num_classes, height, width), device=segmentation_logits.device))
# Stitch patches back into full-sized logit maps # Stitch patches back into full-sized logit maps
for patch_idx, (image_idx, patch_start, patch_end) in enumerate(patch_offsets): for patch_idx, (image_idx, patch_start, patch_end) in enumerate(patch_offsets):
if original_image_sizes[image_idx][0] > original_image_sizes[image_idx][1]: if target_sizes[image_idx][0] > target_sizes[image_idx][1]:
aggregated_logits[image_idx][:, patch_start:patch_end, :] += segmentation_logits[patch_idx] aggregated_logits[image_idx][:, patch_start:patch_end, :] += segmentation_logits[patch_idx]
patch_counts[image_idx][:, patch_start:patch_end, :] += 1 patch_counts[image_idx][:, patch_start:patch_end, :] += 1
else: else:
@@ -392,7 +393,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
averaged_logits = logit_sum / count.clamp(min=1) averaged_logits = logit_sum / count.clamp(min=1)
resized_logits = torch.nn.functional.interpolate( resized_logits = torch.nn.functional.interpolate(
averaged_logits[None, ...], averaged_logits[None, ...],
size=original_image_sizes[idx], size=target_sizes[idx],
mode="bilinear", mode="bilinear",
align_corners=False, align_corners=False,
)[0] )[0]
@@ -404,14 +405,14 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
def unpad_image( def unpad_image(
self, self,
segmentation_logits: torch.Tensor, segmentation_logits: torch.Tensor,
original_image_sizes: list[tuple[int, int]], target_sizes: list[tuple[int, int]],
size: dict[str, int], size: dict[str, int],
) -> list[torch.Tensor]: ) -> list[torch.Tensor]:
"""Restores panoptic segmentation logits to their original image resolutions.""" """Restores panoptic segmentation logits to their original image resolutions."""
resized_logits = [] resized_logits = []
for idx, original_size in enumerate(original_image_sizes): for idx, original_size in enumerate(target_sizes):
target_height, target_width = get_size_with_aspect_ratio( target_height, target_width = get_size_with_aspect_ratio(
original_size, size["shortest_edge"], size["longest_edge"] original_size, size["shortest_edge"], size["longest_edge"]
) )
@@ -425,8 +426,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
def post_process_semantic_segmentation( def post_process_semantic_segmentation(
self, self,
outputs, outputs,
patch_offsets: list[tuple[int, int, int]], target_sizes: list[tuple[int, int]],
original_image_sizes: list[tuple[int, int]],
size: Optional[dict[str, int]] = None, size: Optional[dict[str, int]] = None,
) -> np.ndarray: ) -> np.ndarray:
"""Post-processes model outputs into final semantic segmentation prediction.""" """Post-processes model outputs into final semantic segmentation prediction."""
@@ -435,6 +435,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width] masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width]
class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1] class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1]
patch_offsets = outputs.patch_offsets
output_size = get_target_size(size) output_size = get_target_size(size)
masks_queries_logits = torch.nn.functional.interpolate( masks_queries_logits = torch.nn.functional.interpolate(
@@ -449,15 +450,15 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
segmentation_logits = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs) segmentation_logits = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs)
output_logits = self.merge_image_patches(segmentation_logits, patch_offsets, original_image_sizes, size) output_logits = self.merge_image_patches(segmentation_logits, patch_offsets, target_sizes, size)
preds = torch.stack(output_logits).argmax(dim=1) preds = [logit.argmax(dim=0) for logit in output_logits]
return preds return preds
def post_process_panoptic_segmentation( def post_process_panoptic_segmentation(
self, self,
outputs, outputs,
original_image_sizes: list[tuple[int, int]], target_sizes: list[tuple[int, int]],
threshold: float = 0.8, threshold: float = 0.8,
mask_threshold: float = 0.5, mask_threshold: float = 0.5,
overlap_mask_area_threshold: float = 0.8, overlap_mask_area_threshold: float = 0.8,
@@ -481,7 +482,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
mode="bilinear", mode="bilinear",
) )
mask_probs_batch = self.unpad_image(masks_queries_logits, original_image_sizes, size) mask_probs_batch = self.unpad_image(masks_queries_logits, target_sizes, size)
pred_scores_batch, pred_labels_batch = class_queries_logits.softmax(dim=-1).max(-1) pred_scores_batch, pred_labels_batch = class_queries_logits.softmax(dim=-1).max(-1)
results: list = [] results: list = []
@@ -493,7 +494,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
# No mask found # No mask found
if mask_probs.shape[0] <= 0: if mask_probs.shape[0] <= 0:
height, width = original_image_sizes[i] if original_image_sizes is not None else mask_probs.shape[1:] height, width = target_sizes[i] if target_sizes is not None else mask_probs.shape[1:]
segmentation = torch.zeros((height, width)) - 1 segmentation = torch.zeros((height, width)) - 1
results.append({"segmentation": segmentation, "segments_info": []}) results.append({"segmentation": segmentation, "segments_info": []})
continue continue
@@ -505,16 +506,17 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
stuff_classes=stuff_classes, stuff_classes=stuff_classes,
mask_threshold=mask_threshold, mask_threshold=mask_threshold,
overlap_mask_area_threshold=overlap_mask_area_threshold, overlap_mask_area_threshold=overlap_mask_area_threshold,
target_size=original_image_sizes[i] if original_image_sizes is not None else None, target_size=target_sizes[i] if target_sizes is not None else None,
) )
results.append({"segmentation": segmentation, "segments_info": segments}) results.append({"segmentation": segmentation, "segments_info": segments})
return results return results
@filter_out_non_signature_kwargs()
def post_process_instance_segmentation( def post_process_instance_segmentation(
self, self,
outputs, outputs,
original_image_sizes: list[tuple[int, int]], target_sizes: list[tuple[int, int]],
threshold: float = 0.8, threshold: float = 0.8,
size: Optional[dict[str, int]] = None, size: Optional[dict[str, int]] = None,
): ):
@@ -532,7 +534,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
mode="bilinear", mode="bilinear",
) )
mask_probs_batch = self.unpad_image(masks_queries_logits, original_image_sizes, size) mask_probs_batch = self.unpad_image(masks_queries_logits, target_sizes, size)
device = masks_queries_logits.device device = masks_queries_logits.device
batch_size = class_queries_logits.shape[0] batch_size = class_queries_logits.shape[0]
@@ -554,7 +556,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
) )
pred_scores = scores * mask_scores pred_scores = scores * mask_scores
segmentation = torch.zeros(original_image_sizes[i], device=device) - 1 segmentation = torch.zeros(target_sizes[i], device=device) - 1
instance_maps, segments = [], [] instance_maps, segments = [], []
current_segment_id = 0 current_segment_id = 0

View File

@@ -74,6 +74,8 @@ class EomtForUniversalSegmentationOutput(ModelOutput):
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `tuple(torch.FloatTensor)` (one for each layer) of shape `(batch_size, num_heads, sequence_length, Tuple of `tuple(torch.FloatTensor)` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`. Self and Cross Attentions weights from transformer decoder. sequence_length)`. Self and Cross Attentions weights from transformer decoder.
patch_offsets (`list[torch.Tensor]`, *optional*):
list of tuples indicating the image index and start and end positions of patches for semantic segementation.
""" """
loss: Optional[torch.FloatTensor] = None loss: Optional[torch.FloatTensor] = None
@@ -82,6 +84,7 @@ class EomtForUniversalSegmentationOutput(ModelOutput):
last_hidden_state: Optional[torch.FloatTensor] = None last_hidden_state: Optional[torch.FloatTensor] = None
hidden_states: Optional[tuple[torch.FloatTensor]] = None hidden_states: Optional[tuple[torch.FloatTensor]] = None
attentions: Optional[tuple[torch.FloatTensor]] = None attentions: Optional[tuple[torch.FloatTensor]] = None
patch_offsets: Optional[list[torch.Tensor]] = None
# Adapted from https://github.com/facebookresearch/detectron2/blob/main/projects/PointRend/point_rend/point_features.py # Adapted from https://github.com/facebookresearch/detectron2/blob/main/projects/PointRend/point_rend/point_features.py
@@ -996,7 +999,7 @@ class EomtPreTrainedModel(PreTrainedModel):
base_model_prefix = "eomt" base_model_prefix = "eomt"
main_input_name = "pixel_values" main_input_name = "pixel_values"
supports_gradient_checkpointing = False supports_gradient_checkpointing = False
_no_split_modules = ["EomtMLP"] _no_split_modules = ["EomtLayer"]
_supports_sdpa = True _supports_sdpa = True
_supports_flash_attn_2 = True _supports_flash_attn_2 = True
@@ -1097,13 +1100,16 @@ class EomtForUniversalSegmentation(EomtPreTrainedModel):
class_labels: Optional[list[Tensor]] = None, class_labels: Optional[list[Tensor]] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
patch_offsets: Optional[list[Tensor]] = None,
) -> EomtForUniversalSegmentationOutput: ) -> EomtForUniversalSegmentationOutput:
r""" r"""
mask_labels (`List[torch.Tensor]`, *optional*): mask_labels (`list[torch.Tensor]`, *optional*):
List of mask labels of shape `(num_labels, height, width)` to be fed to a model list of mask labels of shape `(num_labels, height, width)` to be fed to a model
class_labels (`List[torch.LongTensor]`, *optional*): class_labels (`list[torch.LongTensor]`, *optional*):
list of target class labels of shape `(num_labels, height, width)` to be fed to a model. They identify the list of target class labels of shape `(num_labels, height, width)` to be fed to a model. They identify the
labels of `mask_labels`, e.g. the label of `mask_labels[i][j]` if `class_labels[i][j]`. labels of `mask_labels`, e.g. the label of `mask_labels[i][j]` if `class_labels[i][j]`.
patch_offsets (`list[torch.Tensor]`, *optional*):
list of tuples indicating the image index and start and end positions of patches for semantic segementation.
""" """
output_hidden_states = ( output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -1126,7 +1132,7 @@ class EomtForUniversalSegmentation(EomtPreTrainedModel):
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)
if idx == self.num_hidden_layers - self.config.num_blocks: if idx == self.num_hidden_layers - self.config.num_blocks:
query = self.query.weight[None, :, :].expand(hidden_states.shape[0], -1, -1) query = self.query.weight[None, :, :].expand(hidden_states.shape[0], -1, -1).to(hidden_states.device)
hidden_states = torch.cat((query, hidden_states), dim=1) hidden_states = torch.cat((query, hidden_states), dim=1)
if idx >= self.num_hidden_layers - self.config.num_blocks and ( if idx >= self.num_hidden_layers - self.config.num_blocks and (
@@ -1206,6 +1212,7 @@ class EomtForUniversalSegmentation(EomtPreTrainedModel):
last_hidden_state=sequence_output, last_hidden_state=sequence_output,
hidden_states=all_hidden_states, hidden_states=all_hidden_states,
attentions=all_attentions, attentions=all_attentions,
patch_offsets=patch_offsets,
) )
def get_input_embeddings(self): def get_input_embeddings(self):

View File

@@ -226,6 +226,8 @@ class EomtForUniversalSegmentationOutput(ModelOutput):
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `tuple(torch.FloatTensor)` (one for each layer) of shape `(batch_size, num_heads, sequence_length, Tuple of `tuple(torch.FloatTensor)` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`. Self and Cross Attentions weights from transformer decoder. sequence_length)`. Self and Cross Attentions weights from transformer decoder.
patch_offsets (`list[torch.Tensor]`, *optional*):
list of tuples indicating the image index and start and end positions of patches for semantic segementation.
""" """
loss: Optional[torch.FloatTensor] = None loss: Optional[torch.FloatTensor] = None
@@ -234,6 +236,7 @@ class EomtForUniversalSegmentationOutput(ModelOutput):
last_hidden_state: Optional[torch.FloatTensor] = None last_hidden_state: Optional[torch.FloatTensor] = None
hidden_states: Optional[tuple[torch.FloatTensor]] = None hidden_states: Optional[tuple[torch.FloatTensor]] = None
attentions: Optional[tuple[torch.FloatTensor]] = None attentions: Optional[tuple[torch.FloatTensor]] = None
patch_offsets: Optional[list[torch.Tensor]] = None
class EomtLoss(Mask2FormerLoss): class EomtLoss(Mask2FormerLoss):
@@ -368,7 +371,7 @@ class EomtPreTrainedModel(PreTrainedModel):
base_model_prefix = "eomt" base_model_prefix = "eomt"
main_input_name = "pixel_values" main_input_name = "pixel_values"
supports_gradient_checkpointing = False supports_gradient_checkpointing = False
_no_split_modules = ["EomtMLP"] _no_split_modules = ["EomtLayer"]
_supports_sdpa = True _supports_sdpa = True
_supports_flash_attn_2 = True _supports_flash_attn_2 = True
@@ -473,13 +476,16 @@ class EomtForUniversalSegmentation(Mask2FormerForUniversalSegmentation, nn.Modul
class_labels: Optional[list[Tensor]] = None, class_labels: Optional[list[Tensor]] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
patch_offsets: Optional[list[Tensor]] = None,
): ):
r""" r"""
mask_labels (`List[torch.Tensor]`, *optional*): mask_labels (`list[torch.Tensor]`, *optional*):
List of mask labels of shape `(num_labels, height, width)` to be fed to a model list of mask labels of shape `(num_labels, height, width)` to be fed to a model
class_labels (`List[torch.LongTensor]`, *optional*): class_labels (`list[torch.LongTensor]`, *optional*):
list of target class labels of shape `(num_labels, height, width)` to be fed to a model. They identify the list of target class labels of shape `(num_labels, height, width)` to be fed to a model. They identify the
labels of `mask_labels`, e.g. the label of `mask_labels[i][j]` if `class_labels[i][j]`. labels of `mask_labels`, e.g. the label of `mask_labels[i][j]` if `class_labels[i][j]`.
patch_offsets (`list[torch.Tensor]`, *optional*):
list of tuples indicating the image index and start and end positions of patches for semantic segementation.
""" """
output_hidden_states = ( output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -502,7 +508,7 @@ class EomtForUniversalSegmentation(Mask2FormerForUniversalSegmentation, nn.Modul
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)
if idx == self.num_hidden_layers - self.config.num_blocks: if idx == self.num_hidden_layers - self.config.num_blocks:
query = self.query.weight[None, :, :].expand(hidden_states.shape[0], -1, -1) query = self.query.weight[None, :, :].expand(hidden_states.shape[0], -1, -1).to(hidden_states.device)
hidden_states = torch.cat((query, hidden_states), dim=1) hidden_states = torch.cat((query, hidden_states), dim=1)
if idx >= self.num_hidden_layers - self.config.num_blocks and ( if idx >= self.num_hidden_layers - self.config.num_blocks and (
@@ -582,6 +588,7 @@ class EomtForUniversalSegmentation(Mask2FormerForUniversalSegmentation, nn.Modul
last_hidden_state=sequence_output, last_hidden_state=sequence_output,
hidden_states=all_hidden_states, hidden_states=all_hidden_states,
attentions=all_attentions, attentions=all_attentions,
patch_offsets=patch_offsets,
) )

View File

@@ -84,10 +84,11 @@ class EomtImageProcessingTester:
"num_labels": self.num_labels, "num_labels": self.num_labels,
} }
def prepare_fake_eomt_outputs(self, batch_size): def prepare_fake_eomt_outputs(self, batch_size, patch_offsets=None):
return EomtForUniversalSegmentationOutput( return EomtForUniversalSegmentationOutput(
masks_queries_logits=torch.randn((batch_size, self.num_queries, self.height, self.width)), masks_queries_logits=torch.randn((batch_size, self.num_queries, self.height, self.width)),
class_queries_logits=torch.randn((batch_size, self.num_queries, self.num_classes + 1)), class_queries_logits=torch.randn((batch_size, self.num_queries, self.num_classes + 1)),
patch_offsets=patch_offsets,
) )
def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False): def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
@@ -263,13 +264,13 @@ class EomtImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw) image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
inputs = processor(images=image, do_split_image=True, return_tensors="pt") inputs = processor(images=image, do_split_image=True, return_tensors="pt")
patch_offsets = inputs.pop("patch_offsets") patch_offsets = inputs["patch_offsets"]
original_sizes = [image.size[::-1]] target_sizes = [image.size[::-1]]
# For semantic segmentation, the BS of output is 2 coz, two patches are created for the image. # For semantic segmentation, the BS of output is 2 coz, two patches are created for the image.
outputs = self.image_processor_tester.prepare_fake_eomt_outputs(inputs["pixel_values"].shape[0]) outputs = self.image_processor_tester.prepare_fake_eomt_outputs(inputs["pixel_values"].shape[0], patch_offsets)
segmentation = processor.post_process_semantic_segmentation(outputs, patch_offsets, original_sizes) segmentation = processor.post_process_semantic_segmentation(outputs, target_sizes)
self.assertEqual(segmentation[0].shape, (image.height, image.width)) self.assertEqual(segmentation[0].shape, (image.height, image.width))

View File

@@ -17,12 +17,13 @@ import unittest
import requests import requests
from transformers import AutoImageProcessor, EomtConfig, EomtForUniversalSegmentation from transformers import AutoImageProcessor, EomtConfig, EomtForUniversalSegmentation, pipeline
from transformers.testing_utils import require_torch, require_torch_accelerator, require_torch_fp16, slow, torch_device from transformers.testing_utils import require_torch, require_torch_accelerator, require_torch_fp16, slow, torch_device
from transformers.utils import is_torch_available, is_vision_available from transformers.utils import is_torch_available, is_vision_available
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
@@ -100,8 +101,9 @@ class EomtForUniversalSegmentationTester:
@require_torch @require_torch
class EomtForUniversalSegmentationTest(ModelTesterMixin, unittest.TestCase): class EomtForUniversalSegmentationTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (EomtForUniversalSegmentation,) if is_torch_available() else () all_model_classes = (EomtForUniversalSegmentation,) if is_torch_available() else ()
pipeline_model_mapping = {"image-segmentation": EomtForUniversalSegmentation} if is_torch_available() else {}
is_encoder_decoder = False is_encoder_decoder = False
test_pruning = False test_pruning = False
test_head_masking = False test_head_masking = False
@@ -340,7 +342,6 @@ class EomtForUniversalSegmentationIntegrationTest(unittest.TestCase):
image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw) image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
inputs = processor(images=image, return_tensors="pt").to(model.device) inputs = processor(images=image, return_tensors="pt").to(model.device)
patch_offsets = inputs.pop("patch_offsets", None)
with torch.inference_mode(): with torch.inference_mode():
outputs = model(**inputs) outputs = model(**inputs)
@@ -348,11 +349,9 @@ class EomtForUniversalSegmentationIntegrationTest(unittest.TestCase):
self.assertTrue(outputs.class_queries_logits.shape == (2, 100, 151)) self.assertTrue(outputs.class_queries_logits.shape == (2, 100, 151))
self.assertTrue(outputs.masks_queries_logits.shape == (2, 100, 128, 128)) self.assertTrue(outputs.masks_queries_logits.shape == (2, 100, 128, 128))
preds = processor.post_process_semantic_segmentation( preds = processor.post_process_semantic_segmentation(outputs, target_sizes=[(image.size[1], image.size[0])])[0]
outputs, original_image_sizes=[(image.size[1], image.size[0])], patch_offsets=patch_offsets
)
self.assertTrue(preds.shape[1:] == (image.size[1], image.size[0])) self.assertTrue(preds.shape == (image.size[1], image.size[0]))
# fmt: off # fmt: off
EXPECTED_SLICE = torch.tensor([ EXPECTED_SLICE = torch.tensor([
@@ -369,7 +368,7 @@ class EomtForUniversalSegmentationIntegrationTest(unittest.TestCase):
], device=model.device) ], device=model.device)
# fmt: on # fmt: on
output_slice = preds[0, :10, :10] output_slice = preds[:10, :10]
torch.testing.assert_close(output_slice, EXPECTED_SLICE, rtol=1e-2, atol=1e-2) torch.testing.assert_close(output_slice, EXPECTED_SLICE, rtol=1e-2, atol=1e-2)
@slow @slow
@@ -387,9 +386,7 @@ class EomtForUniversalSegmentationIntegrationTest(unittest.TestCase):
self.assertTrue(outputs.class_queries_logits.shape == (1, 200, 134)) self.assertTrue(outputs.class_queries_logits.shape == (1, 200, 134))
self.assertTrue(outputs.masks_queries_logits.shape == (1, 200, 160, 160)) self.assertTrue(outputs.masks_queries_logits.shape == (1, 200, 160, 160))
preds = processor.post_process_panoptic_segmentation( preds = processor.post_process_panoptic_segmentation(outputs, target_sizes=[(image.size[1], image.size[0])])[0]
outputs, original_image_sizes=[(image.size[1], image.size[0])]
)[0]
segmentation, segments_info = preds["segmentation"], preds["segments_info"] segmentation, segments_info = preds["segmentation"], preds["segments_info"]
# fmt: off # fmt: off
@@ -438,9 +435,7 @@ class EomtForUniversalSegmentationIntegrationTest(unittest.TestCase):
self.assertTrue(outputs.class_queries_logits.shape == (1, 200, 81)) self.assertTrue(outputs.class_queries_logits.shape == (1, 200, 81))
self.assertTrue(outputs.masks_queries_logits.shape == (1, 200, 160, 160)) self.assertTrue(outputs.masks_queries_logits.shape == (1, 200, 160, 160))
preds = processor.post_process_instance_segmentation( preds = processor.post_process_instance_segmentation(outputs, target_sizes=[(image.size[1], image.size[0])])[0]
outputs, original_image_sizes=[(image.size[1], image.size[0])]
)[0]
segmentation, segments_info = preds["segmentation"], preds["segments_info"] segmentation, segments_info = preds["segmentation"], preds["segments_info"]
# fmt: off # fmt: off
@@ -473,3 +468,15 @@ class EomtForUniversalSegmentationIntegrationTest(unittest.TestCase):
self.assertEqual(actual["id"], expected["id"]) self.assertEqual(actual["id"], expected["id"])
self.assertEqual(actual["label_id"], expected["label_id"]) self.assertEqual(actual["label_id"], expected["label_id"])
self.assertAlmostEqual(actual["score"], expected["score"], delta=1e-3) self.assertAlmostEqual(actual["score"], expected["score"], delta=1e-3)
@slow
def test_segmentation_pipeline(self):
image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
pipe = pipeline(model=self.model_id, subtask="panoptic", device=torch_device)
output = pipe(image)
EXPECTED_OUTPUT_LABELS = ["cat", "cat", "couch", "remote", "remote"]
output_labels = [segment["label"] for segment in output]
self.assertEqual(output_labels, EXPECTED_OUTPUT_LABELS)