🚨🚨🚨 [eomt] make EoMT compatible with pipeline (#39122)
* Make EoMT compatible with pipeline * Implicit patch offsets * remove patch offsets from arg * Modify tests * Update example * fix proc testcase * Add few more args * add pipeline test suite * fix * docstring fixes * add pipeline test * changes w.r.t review * 🙈 MB * should fix device mismatch * debug * Fixes device mismatch * use decorator * we can split mlp * expected values update --------- Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
This commit is contained in:
@@ -74,20 +74,16 @@ inputs = processor(
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
# Remove Patch Offsets from inputs — only used later for post-processing.
|
||||
patch_offsets = inputs.pop("patch_offsets")
|
||||
|
||||
with torch.inference_mode():
|
||||
outputs = model(**inputs)
|
||||
|
||||
# Prepare the original image size in the format (height, width)
|
||||
original_image_sizes = [(image.height, image.width)]
|
||||
target_sizes = [(image.height, image.width)]
|
||||
|
||||
# Post-process the model outputs to get final segmentation prediction
|
||||
preds = processor.post_process_semantic_segmentation(
|
||||
outputs,
|
||||
patch_offsets=patch_offsets,
|
||||
original_image_sizes=original_image_sizes,
|
||||
target_sizes=target_sizes,
|
||||
)
|
||||
|
||||
# Visualize the segmentation mask
|
||||
@@ -130,12 +126,12 @@ with torch.inference_mode():
|
||||
outputs = model(**inputs)
|
||||
|
||||
# Prepare the original image size in the format (height, width)
|
||||
original_image_sizes = [(image.height, image.width)]
|
||||
target_sizes = [(image.height, image.width)]
|
||||
|
||||
# Post-process the model outputs to get final segmentation prediction
|
||||
preds = processor.post_process_instance_segmentation(
|
||||
outputs,
|
||||
original_image_sizes=original_image_sizes,
|
||||
target_sizes=target_sizes,
|
||||
)
|
||||
|
||||
# Visualize the segmentation mask
|
||||
@@ -173,12 +169,12 @@ with torch.inference_mode():
|
||||
outputs = model(**inputs)
|
||||
|
||||
# Prepare the original image size in the format (height, width)
|
||||
original_image_sizes = [(image.height, image.width)]
|
||||
target_sizes = [(image.height, image.width)]
|
||||
|
||||
# Post-process the model outputs to get final segmentation prediction
|
||||
preds = processor.post_process_panoptic_segmentation(
|
||||
outputs,
|
||||
original_image_sizes=original_image_sizes,
|
||||
target_sizes=target_sizes,
|
||||
)
|
||||
|
||||
# Visualize the panoptic segmentation mask
|
||||
|
||||
@@ -97,7 +97,7 @@ def get_size_with_aspect_ratio(image_size, size, max_size=None) -> tuple[int, in
|
||||
Computes the output image size given the input image size and the desired output size.
|
||||
|
||||
Args:
|
||||
image_size (`Tuple[int, int]`):
|
||||
image_size (`tuple[int, int]`):
|
||||
The input image size.
|
||||
size (`int`):
|
||||
The desired output size.
|
||||
@@ -531,13 +531,13 @@ class EomtImageProcessor(BaseImageProcessor):
|
||||
Image or batch of images to preprocess.
|
||||
segmentation_maps (`ImageInput`, *optional*):
|
||||
The corresponding semantic segmentation maps with the pixel-wise annotations.
|
||||
instance_id_to_semantic_id (`List[Dict[int, int]]` or `Dict[int, int]`, *optional*):
|
||||
instance_id_to_semantic_id (`list[dict[int, int]]` or `dict[int, int]`, *optional*):
|
||||
A mapping between object instance ids and class ids.
|
||||
do_split_image (`bool`, *optional*, defaults to `self.do_split_image`):
|
||||
Whether to split the input images into overlapping patches for semantic segmentation.
|
||||
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
||||
Whether to resize the input images.
|
||||
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
||||
size (`dict[str, int]`, *optional*, defaults to `self.size`):
|
||||
Target size as a dictionary with `"shortest_edge"` and `"longest_edge"` keys.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
|
||||
Resampling filter to use when resizing.
|
||||
@@ -550,9 +550,9 @@ class EomtImageProcessor(BaseImageProcessor):
|
||||
do_pad (`bool`, *optional*, defaults to `False`):
|
||||
Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
|
||||
number of patches in the batch. Padding will be applied to the bottom and right with zeros.
|
||||
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
||||
image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
|
||||
Mean for normalization. Single value or list for each channel.
|
||||
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
||||
image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
|
||||
Standard deviation for normalization. Single value or list for each channel.
|
||||
ignore_index (`int`, *optional*):
|
||||
Label to be assigned to background pixels in segmentation maps. If provided, segmentation map pixels
|
||||
@@ -640,7 +640,7 @@ class EomtImageProcessor(BaseImageProcessor):
|
||||
)
|
||||
|
||||
if do_split_image and patch_offsets:
|
||||
encoded_inputs["patch_offsets"] = patch_offsets
|
||||
encoded_inputs["patch_offsets"] = [torch.tensor(offsets) for offsets in patch_offsets]
|
||||
|
||||
return encoded_inputs
|
||||
|
||||
@@ -663,8 +663,8 @@ class EomtImageProcessor(BaseImageProcessor):
|
||||
each mask.
|
||||
|
||||
Args:
|
||||
pixel_values_list (`List[ImageInput]`):
|
||||
List of images (pixel values) to be padded. Each image should be a tensor of shape `(channels, height,
|
||||
pixel_values_list (`list[ImageInput]`):
|
||||
List of images (pixel values) to be padded. Each image should be a tensor of shape `(channels, height,
|
||||
width)`.
|
||||
|
||||
segmentation_maps (`ImageInput`, *optional*):
|
||||
@@ -678,7 +678,7 @@ class EomtImageProcessor(BaseImageProcessor):
|
||||
- 1 for pixels that are real (i.e. **not masked**),
|
||||
- 0 for pixels that are padding (i.e. **masked**).
|
||||
|
||||
instance_id_to_semantic_id (`List[Dict[int, int]]` or `Dict[int, int]`, *optional*):
|
||||
instance_id_to_semantic_id (`list[dict[int, int]]` or `dict[int, int]`, *optional*):
|
||||
A mapping between object instance ids and class ids. If passed, `segmentation_maps` is treated as an
|
||||
instance segmentation map where each pixel represents an instance id. Can be provided as a single
|
||||
dictionary with a global/dataset-level mapping or as a list of dictionaries (one per image), to map
|
||||
@@ -740,7 +740,7 @@ class EomtImageProcessor(BaseImageProcessor):
|
||||
self,
|
||||
segmentation_logits: torch.Tensor,
|
||||
patch_offsets: list[tuple[int, int, int]],
|
||||
original_image_sizes: list[tuple[int, int]],
|
||||
target_sizes: list[tuple[int, int]],
|
||||
size: dict[str, int],
|
||||
) -> list[torch.Tensor]:
|
||||
"""
|
||||
@@ -750,28 +750,28 @@ class EomtImageProcessor(BaseImageProcessor):
|
||||
segmentation_logits (`torch.Tensor`):
|
||||
A tensor of shape `(num_patches, num_classes, patch_height, patch_width)` representing predicted logits
|
||||
for each image patch.
|
||||
patch_offsets (`List[Tuple[int, int, int]]`):
|
||||
patch_offsets (`list[tuple[int, int, int]]`):
|
||||
A list of tuples where each tuple contains:
|
||||
- `image_index` (int): Index of the original image this patch belongs to.
|
||||
- `start` (int): Start pixel index of the patch along the long dimension (height or width).
|
||||
- `end` (int): End pixel index of the patch along the long dimension.
|
||||
original_image_sizes (`List[Tuple[int, int]]`):
|
||||
List of original (height, width) dimensions for each image before preprocessing.
|
||||
size (`Dict[str, int]`):
|
||||
target_sizes (`list[tuple[int, int]]`):
|
||||
List of original (height, width) dimensions for each image before preprocessing.
|
||||
size (`dict[str, int]`):
|
||||
A size dict which was used to resize.
|
||||
"""
|
||||
num_classes = segmentation_logits.shape[1]
|
||||
aggregated_logits = []
|
||||
patch_counts = []
|
||||
|
||||
for image_size in original_image_sizes:
|
||||
for image_size in target_sizes:
|
||||
height, width = get_size_with_aspect_ratio(image_size, size["shortest_edge"], size["longest_edge"])
|
||||
aggregated_logits.append(torch.zeros((num_classes, height, width), device=segmentation_logits.device))
|
||||
patch_counts.append(torch.zeros((num_classes, height, width), device=segmentation_logits.device))
|
||||
|
||||
# Stitch patches back into full-sized logit maps
|
||||
for patch_idx, (image_idx, patch_start, patch_end) in enumerate(patch_offsets):
|
||||
if original_image_sizes[image_idx][0] > original_image_sizes[image_idx][1]:
|
||||
if target_sizes[image_idx][0] > target_sizes[image_idx][1]:
|
||||
aggregated_logits[image_idx][:, patch_start:patch_end, :] += segmentation_logits[patch_idx]
|
||||
patch_counts[image_idx][:, patch_start:patch_end, :] += 1
|
||||
else:
|
||||
@@ -784,7 +784,7 @@ class EomtImageProcessor(BaseImageProcessor):
|
||||
averaged_logits = logit_sum / count.clamp(min=1)
|
||||
resized_logits = F.interpolate(
|
||||
averaged_logits[None, ...],
|
||||
size=original_image_sizes[idx],
|
||||
size=target_sizes[idx],
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
)[0]
|
||||
@@ -796,14 +796,14 @@ class EomtImageProcessor(BaseImageProcessor):
|
||||
def unpad_image(
|
||||
self,
|
||||
segmentation_logits: torch.Tensor,
|
||||
original_image_sizes: list[tuple[int, int]],
|
||||
target_sizes: list[tuple[int, int]],
|
||||
size: dict[str, int],
|
||||
) -> list[torch.Tensor]:
|
||||
"""Restores panoptic segmentation logits to their original image resolutions."""
|
||||
|
||||
resized_logits = []
|
||||
|
||||
for idx, original_size in enumerate(original_image_sizes):
|
||||
for idx, original_size in enumerate(target_sizes):
|
||||
target_height, target_width = get_size_with_aspect_ratio(
|
||||
original_size, size["shortest_edge"], size["longest_edge"]
|
||||
)
|
||||
@@ -817,8 +817,7 @@ class EomtImageProcessor(BaseImageProcessor):
|
||||
def post_process_semantic_segmentation(
|
||||
self,
|
||||
outputs,
|
||||
patch_offsets: list[tuple[int, int, int]],
|
||||
original_image_sizes: list[tuple[int, int]],
|
||||
target_sizes: list[tuple[int, int]],
|
||||
size: Optional[dict[str, int]] = None,
|
||||
) -> np.ndarray:
|
||||
"""Post-processes model outputs into final semantic segmentation prediction."""
|
||||
@@ -827,6 +826,7 @@ class EomtImageProcessor(BaseImageProcessor):
|
||||
|
||||
masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width]
|
||||
class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1]
|
||||
patch_offsets = outputs.patch_offsets
|
||||
|
||||
output_size = get_target_size(size)
|
||||
masks_queries_logits = F.interpolate(
|
||||
@@ -841,15 +841,15 @@ class EomtImageProcessor(BaseImageProcessor):
|
||||
|
||||
segmentation_logits = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs)
|
||||
|
||||
output_logits = self.merge_image_patches(segmentation_logits, patch_offsets, original_image_sizes, size)
|
||||
output_logits = self.merge_image_patches(segmentation_logits, patch_offsets, target_sizes, size)
|
||||
|
||||
preds = torch.stack(output_logits).argmax(dim=1)
|
||||
preds = [logit.argmax(dim=0) for logit in output_logits]
|
||||
return preds
|
||||
|
||||
def post_process_panoptic_segmentation(
|
||||
self,
|
||||
outputs,
|
||||
original_image_sizes: list[tuple[int, int]],
|
||||
target_sizes: list[tuple[int, int]],
|
||||
threshold: float = 0.8,
|
||||
mask_threshold: float = 0.5,
|
||||
overlap_mask_area_threshold: float = 0.8,
|
||||
@@ -873,7 +873,7 @@ class EomtImageProcessor(BaseImageProcessor):
|
||||
mode="bilinear",
|
||||
)
|
||||
|
||||
mask_probs_batch = self.unpad_image(masks_queries_logits, original_image_sizes, size)
|
||||
mask_probs_batch = self.unpad_image(masks_queries_logits, target_sizes, size)
|
||||
pred_scores_batch, pred_labels_batch = class_queries_logits.softmax(dim=-1).max(-1)
|
||||
|
||||
results: list = []
|
||||
@@ -885,7 +885,7 @@ class EomtImageProcessor(BaseImageProcessor):
|
||||
|
||||
# No mask found
|
||||
if mask_probs.shape[0] <= 0:
|
||||
height, width = original_image_sizes[i] if original_image_sizes is not None else mask_probs.shape[1:]
|
||||
height, width = target_sizes[i] if target_sizes is not None else mask_probs.shape[1:]
|
||||
segmentation = torch.zeros((height, width)) - 1
|
||||
results.append({"segmentation": segmentation, "segments_info": []})
|
||||
continue
|
||||
@@ -897,16 +897,17 @@ class EomtImageProcessor(BaseImageProcessor):
|
||||
stuff_classes=stuff_classes,
|
||||
mask_threshold=mask_threshold,
|
||||
overlap_mask_area_threshold=overlap_mask_area_threshold,
|
||||
target_size=original_image_sizes[i] if original_image_sizes is not None else None,
|
||||
target_size=target_sizes[i] if target_sizes is not None else None,
|
||||
)
|
||||
|
||||
results.append({"segmentation": segmentation, "segments_info": segments})
|
||||
return results
|
||||
|
||||
@filter_out_non_signature_kwargs()
|
||||
def post_process_instance_segmentation(
|
||||
self,
|
||||
outputs,
|
||||
original_image_sizes: list[tuple[int, int]],
|
||||
target_sizes: list[tuple[int, int]],
|
||||
threshold: float = 0.5,
|
||||
size: Optional[dict[str, int]] = None,
|
||||
):
|
||||
@@ -924,7 +925,7 @@ class EomtImageProcessor(BaseImageProcessor):
|
||||
mode="bilinear",
|
||||
)
|
||||
|
||||
mask_probs_batch = self.unpad_image(masks_queries_logits, original_image_sizes, size)
|
||||
mask_probs_batch = self.unpad_image(masks_queries_logits, target_sizes, size)
|
||||
|
||||
device = masks_queries_logits.device
|
||||
batch_size = class_queries_logits.shape[0]
|
||||
@@ -946,7 +947,7 @@ class EomtImageProcessor(BaseImageProcessor):
|
||||
)
|
||||
pred_scores = scores * mask_scores
|
||||
|
||||
segmentation = torch.zeros(original_image_sizes[i], device=device) - 1
|
||||
segmentation = torch.zeros(target_sizes[i], device=device) - 1
|
||||
|
||||
instance_maps, segments = [], []
|
||||
current_segment_id = 0
|
||||
|
||||
@@ -41,6 +41,7 @@ from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
TensorType,
|
||||
auto_docstring,
|
||||
filter_out_non_signature_kwargs,
|
||||
is_torch_available,
|
||||
is_torchvision_available,
|
||||
is_torchvision_v2_available,
|
||||
@@ -268,7 +269,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
|
||||
r"""
|
||||
segmentation_maps (`ImageInput`, *optional*):
|
||||
The segmentation maps to preprocess for corresponding images.
|
||||
instance_id_to_semantic_id (`List[Dict[int, int]]` or `Dict[int, int]`, *optional*):
|
||||
instance_id_to_semantic_id (`list[dict[int, int]]` or `dict[int, int]`, *optional*):
|
||||
A mapping between object instance ids and class ids.
|
||||
"""
|
||||
# args are not validated, but their order in the `preprocess` and `_preprocess` signatures must be the same
|
||||
@@ -340,7 +341,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
|
||||
outputs["class_labels"] = class_labels
|
||||
|
||||
if patch_offsets:
|
||||
outputs["patch_offsets"] = patch_offsets
|
||||
outputs["patch_offsets"] = [torch.tensor(offsets) for offsets in patch_offsets]
|
||||
|
||||
return outputs
|
||||
|
||||
@@ -348,7 +349,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
|
||||
self,
|
||||
segmentation_logits: torch.Tensor,
|
||||
patch_offsets: list[tuple[int, int, int]],
|
||||
original_image_sizes: list[tuple[int, int]],
|
||||
target_sizes: list[tuple[int, int]],
|
||||
size: dict[str, int],
|
||||
) -> list[torch.Tensor]:
|
||||
"""
|
||||
@@ -358,28 +359,28 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
|
||||
segmentation_logits (`torch.Tensor`):
|
||||
A tensor of shape `(num_patches, num_classes, patch_height, patch_width)` representing predicted logits
|
||||
for each image patch.
|
||||
patch_offsets (`List[Tuple[int, int, int]]`):
|
||||
patch_offsets (`list[tuple[int, int, int]]`):
|
||||
A list of tuples where each tuple contains:
|
||||
- `image_index` (int): Index of the original image this patch belongs to.
|
||||
- `start` (int): Start pixel index of the patch along the long dimension (height or width).
|
||||
- `end` (int): End pixel index of the patch along the long dimension.
|
||||
original_image_sizes (`List[Tuple[int, int]]`):
|
||||
List of original (height, width) dimensions for each image before preprocessing.
|
||||
size (`Dict[str, int]`):
|
||||
target_sizes (`list[tuple[int, int]]`):
|
||||
List of original (height, width) dimensions for each image before preprocessing.
|
||||
size (`dict[str, int]`):
|
||||
A size dict which was used to resize.
|
||||
"""
|
||||
num_classes = segmentation_logits.shape[1]
|
||||
aggregated_logits = []
|
||||
patch_counts = []
|
||||
|
||||
for image_size in original_image_sizes:
|
||||
for image_size in target_sizes:
|
||||
height, width = get_size_with_aspect_ratio(image_size, size["shortest_edge"], size["longest_edge"])
|
||||
aggregated_logits.append(torch.zeros((num_classes, height, width), device=segmentation_logits.device))
|
||||
patch_counts.append(torch.zeros((num_classes, height, width), device=segmentation_logits.device))
|
||||
|
||||
# Stitch patches back into full-sized logit maps
|
||||
for patch_idx, (image_idx, patch_start, patch_end) in enumerate(patch_offsets):
|
||||
if original_image_sizes[image_idx][0] > original_image_sizes[image_idx][1]:
|
||||
if target_sizes[image_idx][0] > target_sizes[image_idx][1]:
|
||||
aggregated_logits[image_idx][:, patch_start:patch_end, :] += segmentation_logits[patch_idx]
|
||||
patch_counts[image_idx][:, patch_start:patch_end, :] += 1
|
||||
else:
|
||||
@@ -392,7 +393,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
|
||||
averaged_logits = logit_sum / count.clamp(min=1)
|
||||
resized_logits = torch.nn.functional.interpolate(
|
||||
averaged_logits[None, ...],
|
||||
size=original_image_sizes[idx],
|
||||
size=target_sizes[idx],
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
)[0]
|
||||
@@ -404,14 +405,14 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
|
||||
def unpad_image(
|
||||
self,
|
||||
segmentation_logits: torch.Tensor,
|
||||
original_image_sizes: list[tuple[int, int]],
|
||||
target_sizes: list[tuple[int, int]],
|
||||
size: dict[str, int],
|
||||
) -> list[torch.Tensor]:
|
||||
"""Restores panoptic segmentation logits to their original image resolutions."""
|
||||
|
||||
resized_logits = []
|
||||
|
||||
for idx, original_size in enumerate(original_image_sizes):
|
||||
for idx, original_size in enumerate(target_sizes):
|
||||
target_height, target_width = get_size_with_aspect_ratio(
|
||||
original_size, size["shortest_edge"], size["longest_edge"]
|
||||
)
|
||||
@@ -425,8 +426,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
|
||||
def post_process_semantic_segmentation(
|
||||
self,
|
||||
outputs,
|
||||
patch_offsets: list[tuple[int, int, int]],
|
||||
original_image_sizes: list[tuple[int, int]],
|
||||
target_sizes: list[tuple[int, int]],
|
||||
size: Optional[dict[str, int]] = None,
|
||||
) -> np.ndarray:
|
||||
"""Post-processes model outputs into final semantic segmentation prediction."""
|
||||
@@ -435,6 +435,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
|
||||
|
||||
masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width]
|
||||
class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1]
|
||||
patch_offsets = outputs.patch_offsets
|
||||
|
||||
output_size = get_target_size(size)
|
||||
masks_queries_logits = torch.nn.functional.interpolate(
|
||||
@@ -449,15 +450,15 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
|
||||
|
||||
segmentation_logits = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs)
|
||||
|
||||
output_logits = self.merge_image_patches(segmentation_logits, patch_offsets, original_image_sizes, size)
|
||||
output_logits = self.merge_image_patches(segmentation_logits, patch_offsets, target_sizes, size)
|
||||
|
||||
preds = torch.stack(output_logits).argmax(dim=1)
|
||||
preds = [logit.argmax(dim=0) for logit in output_logits]
|
||||
return preds
|
||||
|
||||
def post_process_panoptic_segmentation(
|
||||
self,
|
||||
outputs,
|
||||
original_image_sizes: list[tuple[int, int]],
|
||||
target_sizes: list[tuple[int, int]],
|
||||
threshold: float = 0.8,
|
||||
mask_threshold: float = 0.5,
|
||||
overlap_mask_area_threshold: float = 0.8,
|
||||
@@ -481,7 +482,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
|
||||
mode="bilinear",
|
||||
)
|
||||
|
||||
mask_probs_batch = self.unpad_image(masks_queries_logits, original_image_sizes, size)
|
||||
mask_probs_batch = self.unpad_image(masks_queries_logits, target_sizes, size)
|
||||
pred_scores_batch, pred_labels_batch = class_queries_logits.softmax(dim=-1).max(-1)
|
||||
|
||||
results: list = []
|
||||
@@ -493,7 +494,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
|
||||
|
||||
# No mask found
|
||||
if mask_probs.shape[0] <= 0:
|
||||
height, width = original_image_sizes[i] if original_image_sizes is not None else mask_probs.shape[1:]
|
||||
height, width = target_sizes[i] if target_sizes is not None else mask_probs.shape[1:]
|
||||
segmentation = torch.zeros((height, width)) - 1
|
||||
results.append({"segmentation": segmentation, "segments_info": []})
|
||||
continue
|
||||
@@ -505,16 +506,17 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
|
||||
stuff_classes=stuff_classes,
|
||||
mask_threshold=mask_threshold,
|
||||
overlap_mask_area_threshold=overlap_mask_area_threshold,
|
||||
target_size=original_image_sizes[i] if original_image_sizes is not None else None,
|
||||
target_size=target_sizes[i] if target_sizes is not None else None,
|
||||
)
|
||||
|
||||
results.append({"segmentation": segmentation, "segments_info": segments})
|
||||
return results
|
||||
|
||||
@filter_out_non_signature_kwargs()
|
||||
def post_process_instance_segmentation(
|
||||
self,
|
||||
outputs,
|
||||
original_image_sizes: list[tuple[int, int]],
|
||||
target_sizes: list[tuple[int, int]],
|
||||
threshold: float = 0.8,
|
||||
size: Optional[dict[str, int]] = None,
|
||||
):
|
||||
@@ -532,7 +534,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
|
||||
mode="bilinear",
|
||||
)
|
||||
|
||||
mask_probs_batch = self.unpad_image(masks_queries_logits, original_image_sizes, size)
|
||||
mask_probs_batch = self.unpad_image(masks_queries_logits, target_sizes, size)
|
||||
|
||||
device = masks_queries_logits.device
|
||||
batch_size = class_queries_logits.shape[0]
|
||||
@@ -554,7 +556,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
|
||||
)
|
||||
pred_scores = scores * mask_scores
|
||||
|
||||
segmentation = torch.zeros(original_image_sizes[i], device=device) - 1
|
||||
segmentation = torch.zeros(target_sizes[i], device=device) - 1
|
||||
|
||||
instance_maps, segments = [], []
|
||||
current_segment_id = 0
|
||||
|
||||
@@ -74,6 +74,8 @@ class EomtForUniversalSegmentationOutput(ModelOutput):
|
||||
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
||||
sequence_length)`. Self and Cross Attentions weights from transformer decoder.
|
||||
patch_offsets (`list[torch.Tensor]`, *optional*):
|
||||
List of tuples indicating the image index and start and end positions of patches for semantic segmentation.
|
||||
"""
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
@@ -82,6 +84,7 @@ class EomtForUniversalSegmentationOutput(ModelOutput):
|
||||
last_hidden_state: Optional[torch.FloatTensor] = None
|
||||
hidden_states: Optional[tuple[torch.FloatTensor]] = None
|
||||
attentions: Optional[tuple[torch.FloatTensor]] = None
|
||||
patch_offsets: Optional[list[torch.Tensor]] = None
|
||||
|
||||
|
||||
# Adapted from https://github.com/facebookresearch/detectron2/blob/main/projects/PointRend/point_rend/point_features.py
|
||||
@@ -996,7 +999,7 @@ class EomtPreTrainedModel(PreTrainedModel):
|
||||
base_model_prefix = "eomt"
|
||||
main_input_name = "pixel_values"
|
||||
supports_gradient_checkpointing = False
|
||||
_no_split_modules = ["EomtMLP"]
|
||||
_no_split_modules = ["EomtLayer"]
|
||||
_supports_sdpa = True
|
||||
_supports_flash_attn_2 = True
|
||||
|
||||
@@ -1097,13 +1100,16 @@ class EomtForUniversalSegmentation(EomtPreTrainedModel):
|
||||
class_labels: Optional[list[Tensor]] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
patch_offsets: Optional[list[Tensor]] = None,
|
||||
) -> EomtForUniversalSegmentationOutput:
|
||||
r"""
|
||||
mask_labels (`List[torch.Tensor]`, *optional*):
|
||||
List of mask labels of shape `(num_labels, height, width)` to be fed to a model
|
||||
class_labels (`List[torch.LongTensor]`, *optional*):
|
||||
mask_labels (`list[torch.Tensor]`, *optional*):
|
||||
List of mask labels of shape `(num_labels, height, width)` to be fed to a model
|
||||
class_labels (`list[torch.LongTensor]`, *optional*):
|
||||
List of target class labels of shape `(num_labels, height, width)` to be fed to a model. They identify the
|
||||
labels of `mask_labels`, e.g. the label of `mask_labels[i][j]` if `class_labels[i][j]`.
|
||||
patch_offsets (`list[torch.Tensor]`, *optional*):
|
||||
List of tuples indicating the image index and start and end positions of patches for semantic segmentation.
|
||||
"""
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
@@ -1126,7 +1132,7 @@ class EomtForUniversalSegmentation(EomtPreTrainedModel):
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
if idx == self.num_hidden_layers - self.config.num_blocks:
|
||||
query = self.query.weight[None, :, :].expand(hidden_states.shape[0], -1, -1)
|
||||
query = self.query.weight[None, :, :].expand(hidden_states.shape[0], -1, -1).to(hidden_states.device)
|
||||
hidden_states = torch.cat((query, hidden_states), dim=1)
|
||||
|
||||
if idx >= self.num_hidden_layers - self.config.num_blocks and (
|
||||
@@ -1206,6 +1212,7 @@ class EomtForUniversalSegmentation(EomtPreTrainedModel):
|
||||
last_hidden_state=sequence_output,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_attentions,
|
||||
patch_offsets=patch_offsets,
|
||||
)
|
||||
|
||||
def get_input_embeddings(self):
|
||||
|
||||
@@ -226,6 +226,8 @@ class EomtForUniversalSegmentationOutput(ModelOutput):
|
||||
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
||||
sequence_length)`. Self and Cross Attentions weights from transformer decoder.
|
||||
patch_offsets (`list[torch.Tensor]`, *optional*):
|
||||
List of tuples indicating the image index and start and end positions of patches for semantic segmentation.
|
||||
"""
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
@@ -234,6 +236,7 @@ class EomtForUniversalSegmentationOutput(ModelOutput):
|
||||
last_hidden_state: Optional[torch.FloatTensor] = None
|
||||
hidden_states: Optional[tuple[torch.FloatTensor]] = None
|
||||
attentions: Optional[tuple[torch.FloatTensor]] = None
|
||||
patch_offsets: Optional[list[torch.Tensor]] = None
|
||||
|
||||
|
||||
class EomtLoss(Mask2FormerLoss):
|
||||
@@ -368,7 +371,7 @@ class EomtPreTrainedModel(PreTrainedModel):
|
||||
base_model_prefix = "eomt"
|
||||
main_input_name = "pixel_values"
|
||||
supports_gradient_checkpointing = False
|
||||
_no_split_modules = ["EomtMLP"]
|
||||
_no_split_modules = ["EomtLayer"]
|
||||
_supports_sdpa = True
|
||||
_supports_flash_attn_2 = True
|
||||
|
||||
@@ -473,13 +476,16 @@ class EomtForUniversalSegmentation(Mask2FormerForUniversalSegmentation, nn.Modul
|
||||
class_labels: Optional[list[Tensor]] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
patch_offsets: Optional[list[Tensor]] = None,
|
||||
):
|
||||
r"""
|
||||
mask_labels (`List[torch.Tensor]`, *optional*):
|
||||
List of mask labels of shape `(num_labels, height, width)` to be fed to a model
|
||||
class_labels (`List[torch.LongTensor]`, *optional*):
|
||||
mask_labels (`list[torch.Tensor]`, *optional*):
|
||||
List of mask labels of shape `(num_labels, height, width)` to be fed to a model
|
||||
class_labels (`list[torch.LongTensor]`, *optional*):
|
||||
List of target class labels of shape `(num_labels, height, width)` to be fed to a model. They identify the
|
||||
labels of `mask_labels`, e.g. the label of `mask_labels[i][j]` if `class_labels[i][j]`.
|
||||
patch_offsets (`list[torch.Tensor]`, *optional*):
|
||||
List of tuples indicating the image index and start and end positions of patches for semantic segmentation.
|
||||
"""
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
@@ -502,7 +508,7 @@ class EomtForUniversalSegmentation(Mask2FormerForUniversalSegmentation, nn.Modul
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
if idx == self.num_hidden_layers - self.config.num_blocks:
|
||||
query = self.query.weight[None, :, :].expand(hidden_states.shape[0], -1, -1)
|
||||
query = self.query.weight[None, :, :].expand(hidden_states.shape[0], -1, -1).to(hidden_states.device)
|
||||
hidden_states = torch.cat((query, hidden_states), dim=1)
|
||||
|
||||
if idx >= self.num_hidden_layers - self.config.num_blocks and (
|
||||
@@ -582,6 +588,7 @@ class EomtForUniversalSegmentation(Mask2FormerForUniversalSegmentation, nn.Modul
|
||||
last_hidden_state=sequence_output,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_attentions,
|
||||
patch_offsets=patch_offsets,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -84,10 +84,11 @@ class EomtImageProcessingTester:
|
||||
"num_labels": self.num_labels,
|
||||
}
|
||||
|
||||
def prepare_fake_eomt_outputs(self, batch_size):
|
||||
def prepare_fake_eomt_outputs(self, batch_size, patch_offsets=None):
|
||||
return EomtForUniversalSegmentationOutput(
|
||||
masks_queries_logits=torch.randn((batch_size, self.num_queries, self.height, self.width)),
|
||||
class_queries_logits=torch.randn((batch_size, self.num_queries, self.num_classes + 1)),
|
||||
patch_offsets=patch_offsets,
|
||||
)
|
||||
|
||||
def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
|
||||
@@ -263,13 +264,13 @@ class EomtImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
|
||||
|
||||
inputs = processor(images=image, do_split_image=True, return_tensors="pt")
|
||||
patch_offsets = inputs.pop("patch_offsets")
|
||||
patch_offsets = inputs["patch_offsets"]
|
||||
|
||||
original_sizes = [image.size[::-1]]
|
||||
target_sizes = [image.size[::-1]]
|
||||
|
||||
# For semantic segmentation, the output batch size is 2 because two patches are created for the image.
|
||||
outputs = self.image_processor_tester.prepare_fake_eomt_outputs(inputs["pixel_values"].shape[0])
|
||||
segmentation = processor.post_process_semantic_segmentation(outputs, patch_offsets, original_sizes)
|
||||
outputs = self.image_processor_tester.prepare_fake_eomt_outputs(inputs["pixel_values"].shape[0], patch_offsets)
|
||||
segmentation = processor.post_process_semantic_segmentation(outputs, target_sizes)
|
||||
|
||||
self.assertEqual(segmentation[0].shape, (image.height, image.width))
|
||||
|
||||
|
||||
@@ -17,12 +17,13 @@ import unittest
|
||||
|
||||
import requests
|
||||
|
||||
from transformers import AutoImageProcessor, EomtConfig, EomtForUniversalSegmentation
|
||||
from transformers import AutoImageProcessor, EomtConfig, EomtForUniversalSegmentation, pipeline
|
||||
from transformers.testing_utils import require_torch, require_torch_accelerator, require_torch_fp16, slow, torch_device
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor
|
||||
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -100,8 +101,9 @@ class EomtForUniversalSegmentationTester:
|
||||
|
||||
|
||||
@require_torch
|
||||
class EomtForUniversalSegmentationTest(ModelTesterMixin, unittest.TestCase):
|
||||
class EomtForUniversalSegmentationTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (EomtForUniversalSegmentation,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = {"image-segmentation": EomtForUniversalSegmentation} if is_torch_available() else {}
|
||||
is_encoder_decoder = False
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
@@ -340,7 +342,6 @@ class EomtForUniversalSegmentationIntegrationTest(unittest.TestCase):
|
||||
image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
|
||||
|
||||
inputs = processor(images=image, return_tensors="pt").to(model.device)
|
||||
patch_offsets = inputs.pop("patch_offsets", None)
|
||||
|
||||
with torch.inference_mode():
|
||||
outputs = model(**inputs)
|
||||
@@ -348,11 +349,9 @@ class EomtForUniversalSegmentationIntegrationTest(unittest.TestCase):
|
||||
self.assertTrue(outputs.class_queries_logits.shape == (2, 100, 151))
|
||||
self.assertTrue(outputs.masks_queries_logits.shape == (2, 100, 128, 128))
|
||||
|
||||
preds = processor.post_process_semantic_segmentation(
|
||||
outputs, original_image_sizes=[(image.size[1], image.size[0])], patch_offsets=patch_offsets
|
||||
)
|
||||
preds = processor.post_process_semantic_segmentation(outputs, target_sizes=[(image.size[1], image.size[0])])[0]
|
||||
|
||||
self.assertTrue(preds.shape[1:] == (image.size[1], image.size[0]))
|
||||
self.assertTrue(preds.shape == (image.size[1], image.size[0]))
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_SLICE = torch.tensor([
|
||||
@@ -369,7 +368,7 @@ class EomtForUniversalSegmentationIntegrationTest(unittest.TestCase):
|
||||
], device=model.device)
|
||||
# fmt: on
|
||||
|
||||
output_slice = preds[0, :10, :10]
|
||||
output_slice = preds[:10, :10]
|
||||
torch.testing.assert_close(output_slice, EXPECTED_SLICE, rtol=1e-2, atol=1e-2)
|
||||
|
||||
@slow
|
||||
@@ -387,9 +386,7 @@ class EomtForUniversalSegmentationIntegrationTest(unittest.TestCase):
|
||||
self.assertTrue(outputs.class_queries_logits.shape == (1, 200, 134))
|
||||
self.assertTrue(outputs.masks_queries_logits.shape == (1, 200, 160, 160))
|
||||
|
||||
preds = processor.post_process_panoptic_segmentation(
|
||||
outputs, original_image_sizes=[(image.size[1], image.size[0])]
|
||||
)[0]
|
||||
preds = processor.post_process_panoptic_segmentation(outputs, target_sizes=[(image.size[1], image.size[0])])[0]
|
||||
segmentation, segments_info = preds["segmentation"], preds["segments_info"]
|
||||
|
||||
# fmt: off
|
||||
@@ -438,9 +435,7 @@ class EomtForUniversalSegmentationIntegrationTest(unittest.TestCase):
|
||||
self.assertTrue(outputs.class_queries_logits.shape == (1, 200, 81))
|
||||
self.assertTrue(outputs.masks_queries_logits.shape == (1, 200, 160, 160))
|
||||
|
||||
preds = processor.post_process_instance_segmentation(
|
||||
outputs, original_image_sizes=[(image.size[1], image.size[0])]
|
||||
)[0]
|
||||
preds = processor.post_process_instance_segmentation(outputs, target_sizes=[(image.size[1], image.size[0])])[0]
|
||||
segmentation, segments_info = preds["segmentation"], preds["segments_info"]
|
||||
|
||||
# fmt: off
|
||||
@@ -473,3 +468,15 @@ class EomtForUniversalSegmentationIntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(actual["id"], expected["id"])
|
||||
self.assertEqual(actual["label_id"], expected["label_id"])
|
||||
self.assertAlmostEqual(actual["score"], expected["score"], delta=1e-3)
|
||||
|
||||
@slow
|
||||
def test_segmentation_pipeline(self):
|
||||
image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
|
||||
|
||||
pipe = pipeline(model=self.model_id, subtask="panoptic", device=torch_device)
|
||||
output = pipe(image)
|
||||
|
||||
EXPECTED_OUTPUT_LABELS = ["cat", "cat", "couch", "remote", "remote"]
|
||||
|
||||
output_labels = [segment["label"] for segment in output]
|
||||
self.assertEqual(output_labels, EXPECTED_OUTPUT_LABELS)
|
||||
|
||||
Reference in New Issue
Block a user