🚨🚨🚨 [eomt] make EoMT compatible with pipeline (#39122)
* Make EoMT compatible with pipeline * Implicit patch offsets * remove patch offsets from arg * Modify tests * Update example * fix proc testcase * Add few more args * add pipeline test suite * fix * docstring fixes * add pipeline test * changes w.r.t review * 🙈 MB * should fix device mismatch * debug * Fixes device mismatch * use decorator * we can split mlp * expected values update --------- Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
This commit is contained in:
@@ -74,20 +74,16 @@ inputs = processor(
|
|||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Remove Patch Offsets from inputs — only used later for post-processing.
|
|
||||||
patch_offsets = inputs.pop("patch_offsets")
|
|
||||||
|
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
|
|
||||||
# Prepare the original image size in the format (height, width)
|
# Prepare the original image size in the format (height, width)
|
||||||
original_image_sizes = [(image.height, image.width)]
|
target_sizes = [(image.height, image.width)]
|
||||||
|
|
||||||
# Post-process the model outputs to get final segmentation prediction
|
# Post-process the model outputs to get final segmentation prediction
|
||||||
preds = processor.post_process_semantic_segmentation(
|
preds = processor.post_process_semantic_segmentation(
|
||||||
outputs,
|
outputs,
|
||||||
patch_offsets=patch_offsets,
|
target_sizes=target_sizes,
|
||||||
original_image_sizes=original_image_sizes,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Visualize the segmentation mask
|
# Visualize the segmentation mask
|
||||||
@@ -130,12 +126,12 @@ with torch.inference_mode():
|
|||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
|
|
||||||
# Prepare the original image size in the format (height, width)
|
# Prepare the original image size in the format (height, width)
|
||||||
original_image_sizes = [(image.height, image.width)]
|
target_sizes = [(image.height, image.width)]
|
||||||
|
|
||||||
# Post-process the model outputs to get final segmentation prediction
|
# Post-process the model outputs to get final segmentation prediction
|
||||||
preds = processor.post_process_instance_segmentation(
|
preds = processor.post_process_instance_segmentation(
|
||||||
outputs,
|
outputs,
|
||||||
original_image_sizes=original_image_sizes,
|
target_sizes=target_sizes,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Visualize the segmentation mask
|
# Visualize the segmentation mask
|
||||||
@@ -173,12 +169,12 @@ with torch.inference_mode():
|
|||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
|
|
||||||
# Prepare the original image size in the format (height, width)
|
# Prepare the original image size in the format (height, width)
|
||||||
original_image_sizes = [(image.height, image.width)]
|
target_sizes = [(image.height, image.width)]
|
||||||
|
|
||||||
# Post-process the model outputs to get final segmentation prediction
|
# Post-process the model outputs to get final segmentation prediction
|
||||||
preds = processor.post_process_panoptic_segmentation(
|
preds = processor.post_process_panoptic_segmentation(
|
||||||
outputs,
|
outputs,
|
||||||
original_image_sizes=original_image_sizes,
|
target_sizes=target_sizes,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Visualize the panoptic segmentation mask
|
# Visualize the panoptic segmentation mask
|
||||||
|
|||||||
@@ -97,7 +97,7 @@ def get_size_with_aspect_ratio(image_size, size, max_size=None) -> tuple[int, in
|
|||||||
Computes the output image size given the input image size and the desired output size.
|
Computes the output image size given the input image size and the desired output size.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
image_size (`Tuple[int, int]`):
|
image_size (`tuple[int, int]`):
|
||||||
The input image size.
|
The input image size.
|
||||||
size (`int`):
|
size (`int`):
|
||||||
The desired output size.
|
The desired output size.
|
||||||
@@ -531,13 +531,13 @@ class EomtImageProcessor(BaseImageProcessor):
|
|||||||
Image or batch of images to preprocess.
|
Image or batch of images to preprocess.
|
||||||
segmentation_maps (`ImageInput`, *optional*):
|
segmentation_maps (`ImageInput`, *optional*):
|
||||||
The corresponding semantic segmentation maps with the pixel-wise annotations.
|
The corresponding semantic segmentation maps with the pixel-wise annotations.
|
||||||
instance_id_to_semantic_id (`List[Dict[int, int]]` or `Dict[int, int]`, *optional*):
|
instance_id_to_semantic_id (`list[dict[int, int]]` or `dict[int, int]`, *optional*):
|
||||||
A mapping between object instance ids and class ids.
|
A mapping between object instance ids and class ids.
|
||||||
do_split_image (`bool`, *optional*, defaults to `self.do_split_image`):
|
do_split_image (`bool`, *optional*, defaults to `self.do_split_image`):
|
||||||
Whether to split the input images into overlapping patches for semantic segmentation.
|
Whether to split the input images into overlapping patches for semantic segmentation.
|
||||||
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
||||||
Whether to resize the input images.
|
Whether to resize the input images.
|
||||||
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
size (`dict[str, int]`, *optional*, defaults to `self.size`):
|
||||||
Target size as a dictionary with `"shortest_edge"` and `"longest_edge"` keys.
|
Target size as a dictionary with `"shortest_edge"` and `"longest_edge"` keys.
|
||||||
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
|
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
|
||||||
Resampling filter to use when resizing.
|
Resampling filter to use when resizing.
|
||||||
@@ -550,9 +550,9 @@ class EomtImageProcessor(BaseImageProcessor):
|
|||||||
do_pad (`bool`, *optional*, defaults to `False`):
|
do_pad (`bool`, *optional*, defaults to `False`):
|
||||||
Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
|
Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
|
||||||
number of patches in the batch. Padding will be applied to the bottom and right with zeros.
|
number of patches in the batch. Padding will be applied to the bottom and right with zeros.
|
||||||
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
|
||||||
Mean for normalization. Single value or list for each channel.
|
Mean for normalization. Single value or list for each channel.
|
||||||
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
|
||||||
Standard deviation for normalization. Single value or list for each channel.
|
Standard deviation for normalization. Single value or list for each channel.
|
||||||
ignore_index (`int`, *optional*):
|
ignore_index (`int`, *optional*):
|
||||||
Label to be assigned to background pixels in segmentation maps. If provided, segmentation map pixels
|
Label to be assigned to background pixels in segmentation maps. If provided, segmentation map pixels
|
||||||
@@ -640,7 +640,7 @@ class EomtImageProcessor(BaseImageProcessor):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if do_split_image and patch_offsets:
|
if do_split_image and patch_offsets:
|
||||||
encoded_inputs["patch_offsets"] = patch_offsets
|
encoded_inputs["patch_offsets"] = [torch.tensor(offsets) for offsets in patch_offsets]
|
||||||
|
|
||||||
return encoded_inputs
|
return encoded_inputs
|
||||||
|
|
||||||
@@ -663,8 +663,8 @@ class EomtImageProcessor(BaseImageProcessor):
|
|||||||
each mask.
|
each mask.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
pixel_values_list (`List[ImageInput]`):
|
pixel_values_list (`list[ImageInput]`):
|
||||||
List of images (pixel values) to be padded. Each image should be a tensor of shape `(channels, height,
|
list of images (pixel values) to be padded. Each image should be a tensor of shape `(channels, height,
|
||||||
width)`.
|
width)`.
|
||||||
|
|
||||||
segmentation_maps (`ImageInput`, *optional*):
|
segmentation_maps (`ImageInput`, *optional*):
|
||||||
@@ -678,7 +678,7 @@ class EomtImageProcessor(BaseImageProcessor):
|
|||||||
- 1 for pixels that are real (i.e. **not masked**),
|
- 1 for pixels that are real (i.e. **not masked**),
|
||||||
- 0 for pixels that are padding (i.e. **masked**).
|
- 0 for pixels that are padding (i.e. **masked**).
|
||||||
|
|
||||||
instance_id_to_semantic_id (`List[Dict[int, int]]` or `Dict[int, int]`, *optional*):
|
instance_id_to_semantic_id (`list[dict[int, int]]` or `dict[int, int]`, *optional*):
|
||||||
A mapping between object instance ids and class ids. If passed, `segmentation_maps` is treated as an
|
A mapping between object instance ids and class ids. If passed, `segmentation_maps` is treated as an
|
||||||
instance segmentation map where each pixel represents an instance id. Can be provided as a single
|
instance segmentation map where each pixel represents an instance id. Can be provided as a single
|
||||||
dictionary with a global/dataset-level mapping or as a list of dictionaries (one per image), to map
|
dictionary with a global/dataset-level mapping or as a list of dictionaries (one per image), to map
|
||||||
@@ -740,7 +740,7 @@ class EomtImageProcessor(BaseImageProcessor):
|
|||||||
self,
|
self,
|
||||||
segmentation_logits: torch.Tensor,
|
segmentation_logits: torch.Tensor,
|
||||||
patch_offsets: list[tuple[int, int, int]],
|
patch_offsets: list[tuple[int, int, int]],
|
||||||
original_image_sizes: list[tuple[int, int]],
|
target_sizes: list[tuple[int, int]],
|
||||||
size: dict[str, int],
|
size: dict[str, int],
|
||||||
) -> list[torch.Tensor]:
|
) -> list[torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
@@ -750,28 +750,28 @@ class EomtImageProcessor(BaseImageProcessor):
|
|||||||
segmentation_logits (`torch.Tensor`):
|
segmentation_logits (`torch.Tensor`):
|
||||||
A tensor of shape `(num_patches, num_classes, patch_height, patch_width)` representing predicted logits
|
A tensor of shape `(num_patches, num_classes, patch_height, patch_width)` representing predicted logits
|
||||||
for each image patch.
|
for each image patch.
|
||||||
patch_offsets (`List[Tuple[int, int, int]]`):
|
patch_offsets (`list[tuple[int, int, int]]`):
|
||||||
A list of tuples where each tuple contains:
|
A list of tuples where each tuple contains:
|
||||||
- `image_index` (int): Index of the original image this patch belongs to.
|
- `image_index` (int): Index of the original image this patch belongs to.
|
||||||
- `start` (int): Start pixel index of the patch along the long dimension (height or width).
|
- `start` (int): Start pixel index of the patch along the long dimension (height or width).
|
||||||
- `end` (int): End pixel index of the patch along the long dimension.
|
- `end` (int): End pixel index of the patch along the long dimension.
|
||||||
original_image_sizes (`List[Tuple[int, int]]`):
|
target_sizes (`list[tuple[int, int]]`):
|
||||||
List of original (height, width) dimensions for each image before preprocessing.
|
list of original (height, width) dimensions for each image before preprocessing.
|
||||||
size (`Dict[str, int]`):
|
size (`dict[str, int]`):
|
||||||
A size dict which was used to resize.
|
A size dict which was used to resize.
|
||||||
"""
|
"""
|
||||||
num_classes = segmentation_logits.shape[1]
|
num_classes = segmentation_logits.shape[1]
|
||||||
aggregated_logits = []
|
aggregated_logits = []
|
||||||
patch_counts = []
|
patch_counts = []
|
||||||
|
|
||||||
for image_size in original_image_sizes:
|
for image_size in target_sizes:
|
||||||
height, width = get_size_with_aspect_ratio(image_size, size["shortest_edge"], size["longest_edge"])
|
height, width = get_size_with_aspect_ratio(image_size, size["shortest_edge"], size["longest_edge"])
|
||||||
aggregated_logits.append(torch.zeros((num_classes, height, width), device=segmentation_logits.device))
|
aggregated_logits.append(torch.zeros((num_classes, height, width), device=segmentation_logits.device))
|
||||||
patch_counts.append(torch.zeros((num_classes, height, width), device=segmentation_logits.device))
|
patch_counts.append(torch.zeros((num_classes, height, width), device=segmentation_logits.device))
|
||||||
|
|
||||||
# Stitch patches back into full-sized logit maps
|
# Stitch patches back into full-sized logit maps
|
||||||
for patch_idx, (image_idx, patch_start, patch_end) in enumerate(patch_offsets):
|
for patch_idx, (image_idx, patch_start, patch_end) in enumerate(patch_offsets):
|
||||||
if original_image_sizes[image_idx][0] > original_image_sizes[image_idx][1]:
|
if target_sizes[image_idx][0] > target_sizes[image_idx][1]:
|
||||||
aggregated_logits[image_idx][:, patch_start:patch_end, :] += segmentation_logits[patch_idx]
|
aggregated_logits[image_idx][:, patch_start:patch_end, :] += segmentation_logits[patch_idx]
|
||||||
patch_counts[image_idx][:, patch_start:patch_end, :] += 1
|
patch_counts[image_idx][:, patch_start:patch_end, :] += 1
|
||||||
else:
|
else:
|
||||||
@@ -784,7 +784,7 @@ class EomtImageProcessor(BaseImageProcessor):
|
|||||||
averaged_logits = logit_sum / count.clamp(min=1)
|
averaged_logits = logit_sum / count.clamp(min=1)
|
||||||
resized_logits = F.interpolate(
|
resized_logits = F.interpolate(
|
||||||
averaged_logits[None, ...],
|
averaged_logits[None, ...],
|
||||||
size=original_image_sizes[idx],
|
size=target_sizes[idx],
|
||||||
mode="bilinear",
|
mode="bilinear",
|
||||||
align_corners=False,
|
align_corners=False,
|
||||||
)[0]
|
)[0]
|
||||||
@@ -796,14 +796,14 @@ class EomtImageProcessor(BaseImageProcessor):
|
|||||||
def unpad_image(
|
def unpad_image(
|
||||||
self,
|
self,
|
||||||
segmentation_logits: torch.Tensor,
|
segmentation_logits: torch.Tensor,
|
||||||
original_image_sizes: list[tuple[int, int]],
|
target_sizes: list[tuple[int, int]],
|
||||||
size: dict[str, int],
|
size: dict[str, int],
|
||||||
) -> list[torch.Tensor]:
|
) -> list[torch.Tensor]:
|
||||||
"""Restores panoptic segmentation logits to their original image resolutions."""
|
"""Restores panoptic segmentation logits to their original image resolutions."""
|
||||||
|
|
||||||
resized_logits = []
|
resized_logits = []
|
||||||
|
|
||||||
for idx, original_size in enumerate(original_image_sizes):
|
for idx, original_size in enumerate(target_sizes):
|
||||||
target_height, target_width = get_size_with_aspect_ratio(
|
target_height, target_width = get_size_with_aspect_ratio(
|
||||||
original_size, size["shortest_edge"], size["longest_edge"]
|
original_size, size["shortest_edge"], size["longest_edge"]
|
||||||
)
|
)
|
||||||
@@ -817,8 +817,7 @@ class EomtImageProcessor(BaseImageProcessor):
|
|||||||
def post_process_semantic_segmentation(
|
def post_process_semantic_segmentation(
|
||||||
self,
|
self,
|
||||||
outputs,
|
outputs,
|
||||||
patch_offsets: list[tuple[int, int, int]],
|
target_sizes: list[tuple[int, int]],
|
||||||
original_image_sizes: list[tuple[int, int]],
|
|
||||||
size: Optional[dict[str, int]] = None,
|
size: Optional[dict[str, int]] = None,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""Post-processes model outputs into final semantic segmentation prediction."""
|
"""Post-processes model outputs into final semantic segmentation prediction."""
|
||||||
@@ -827,6 +826,7 @@ class EomtImageProcessor(BaseImageProcessor):
|
|||||||
|
|
||||||
masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width]
|
masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width]
|
||||||
class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1]
|
class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1]
|
||||||
|
patch_offsets = outputs.patch_offsets
|
||||||
|
|
||||||
output_size = get_target_size(size)
|
output_size = get_target_size(size)
|
||||||
masks_queries_logits = F.interpolate(
|
masks_queries_logits = F.interpolate(
|
||||||
@@ -841,15 +841,15 @@ class EomtImageProcessor(BaseImageProcessor):
|
|||||||
|
|
||||||
segmentation_logits = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs)
|
segmentation_logits = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs)
|
||||||
|
|
||||||
output_logits = self.merge_image_patches(segmentation_logits, patch_offsets, original_image_sizes, size)
|
output_logits = self.merge_image_patches(segmentation_logits, patch_offsets, target_sizes, size)
|
||||||
|
|
||||||
preds = torch.stack(output_logits).argmax(dim=1)
|
preds = [logit.argmax(dim=0) for logit in output_logits]
|
||||||
return preds
|
return preds
|
||||||
|
|
||||||
def post_process_panoptic_segmentation(
|
def post_process_panoptic_segmentation(
|
||||||
self,
|
self,
|
||||||
outputs,
|
outputs,
|
||||||
original_image_sizes: list[tuple[int, int]],
|
target_sizes: list[tuple[int, int]],
|
||||||
threshold: float = 0.8,
|
threshold: float = 0.8,
|
||||||
mask_threshold: float = 0.5,
|
mask_threshold: float = 0.5,
|
||||||
overlap_mask_area_threshold: float = 0.8,
|
overlap_mask_area_threshold: float = 0.8,
|
||||||
@@ -873,7 +873,7 @@ class EomtImageProcessor(BaseImageProcessor):
|
|||||||
mode="bilinear",
|
mode="bilinear",
|
||||||
)
|
)
|
||||||
|
|
||||||
mask_probs_batch = self.unpad_image(masks_queries_logits, original_image_sizes, size)
|
mask_probs_batch = self.unpad_image(masks_queries_logits, target_sizes, size)
|
||||||
pred_scores_batch, pred_labels_batch = class_queries_logits.softmax(dim=-1).max(-1)
|
pred_scores_batch, pred_labels_batch = class_queries_logits.softmax(dim=-1).max(-1)
|
||||||
|
|
||||||
results: list = []
|
results: list = []
|
||||||
@@ -885,7 +885,7 @@ class EomtImageProcessor(BaseImageProcessor):
|
|||||||
|
|
||||||
# No mask found
|
# No mask found
|
||||||
if mask_probs.shape[0] <= 0:
|
if mask_probs.shape[0] <= 0:
|
||||||
height, width = original_image_sizes[i] if original_image_sizes is not None else mask_probs.shape[1:]
|
height, width = target_sizes[i] if target_sizes is not None else mask_probs.shape[1:]
|
||||||
segmentation = torch.zeros((height, width)) - 1
|
segmentation = torch.zeros((height, width)) - 1
|
||||||
results.append({"segmentation": segmentation, "segments_info": []})
|
results.append({"segmentation": segmentation, "segments_info": []})
|
||||||
continue
|
continue
|
||||||
@@ -897,16 +897,17 @@ class EomtImageProcessor(BaseImageProcessor):
|
|||||||
stuff_classes=stuff_classes,
|
stuff_classes=stuff_classes,
|
||||||
mask_threshold=mask_threshold,
|
mask_threshold=mask_threshold,
|
||||||
overlap_mask_area_threshold=overlap_mask_area_threshold,
|
overlap_mask_area_threshold=overlap_mask_area_threshold,
|
||||||
target_size=original_image_sizes[i] if original_image_sizes is not None else None,
|
target_size=target_sizes[i] if target_sizes is not None else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
results.append({"segmentation": segmentation, "segments_info": segments})
|
results.append({"segmentation": segmentation, "segments_info": segments})
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
@filter_out_non_signature_kwargs()
|
||||||
def post_process_instance_segmentation(
|
def post_process_instance_segmentation(
|
||||||
self,
|
self,
|
||||||
outputs,
|
outputs,
|
||||||
original_image_sizes: list[tuple[int, int]],
|
target_sizes: list[tuple[int, int]],
|
||||||
threshold: float = 0.5,
|
threshold: float = 0.5,
|
||||||
size: Optional[dict[str, int]] = None,
|
size: Optional[dict[str, int]] = None,
|
||||||
):
|
):
|
||||||
@@ -924,7 +925,7 @@ class EomtImageProcessor(BaseImageProcessor):
|
|||||||
mode="bilinear",
|
mode="bilinear",
|
||||||
)
|
)
|
||||||
|
|
||||||
mask_probs_batch = self.unpad_image(masks_queries_logits, original_image_sizes, size)
|
mask_probs_batch = self.unpad_image(masks_queries_logits, target_sizes, size)
|
||||||
|
|
||||||
device = masks_queries_logits.device
|
device = masks_queries_logits.device
|
||||||
batch_size = class_queries_logits.shape[0]
|
batch_size = class_queries_logits.shape[0]
|
||||||
@@ -946,7 +947,7 @@ class EomtImageProcessor(BaseImageProcessor):
|
|||||||
)
|
)
|
||||||
pred_scores = scores * mask_scores
|
pred_scores = scores * mask_scores
|
||||||
|
|
||||||
segmentation = torch.zeros(original_image_sizes[i], device=device) - 1
|
segmentation = torch.zeros(target_sizes[i], device=device) - 1
|
||||||
|
|
||||||
instance_maps, segments = [], []
|
instance_maps, segments = [], []
|
||||||
current_segment_id = 0
|
current_segment_id = 0
|
||||||
|
|||||||
@@ -41,6 +41,7 @@ from ...processing_utils import Unpack
|
|||||||
from ...utils import (
|
from ...utils import (
|
||||||
TensorType,
|
TensorType,
|
||||||
auto_docstring,
|
auto_docstring,
|
||||||
|
filter_out_non_signature_kwargs,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
is_torchvision_available,
|
is_torchvision_available,
|
||||||
is_torchvision_v2_available,
|
is_torchvision_v2_available,
|
||||||
@@ -268,7 +269,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
|
|||||||
r"""
|
r"""
|
||||||
segmentation_maps (`ImageInput`, *optional*):
|
segmentation_maps (`ImageInput`, *optional*):
|
||||||
The segmentation maps to preprocess for corresponding images.
|
The segmentation maps to preprocess for corresponding images.
|
||||||
instance_id_to_semantic_id (`List[Dict[int, int]]` or `Dict[int, int]`, *optional*):
|
instance_id_to_semantic_id (`list[dict[int, int]]` or `dict[int, int]`, *optional*):
|
||||||
A mapping between object instance ids and class ids.
|
A mapping between object instance ids and class ids.
|
||||||
"""
|
"""
|
||||||
# args are not validated, but their order in the `preprocess` and `_preprocess` signatures must be the same
|
# args are not validated, but their order in the `preprocess` and `_preprocess` signatures must be the same
|
||||||
@@ -340,7 +341,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
|
|||||||
outputs["class_labels"] = class_labels
|
outputs["class_labels"] = class_labels
|
||||||
|
|
||||||
if patch_offsets:
|
if patch_offsets:
|
||||||
outputs["patch_offsets"] = patch_offsets
|
outputs["patch_offsets"] = [torch.tensor(offsets) for offsets in patch_offsets]
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
@@ -348,7 +349,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
|
|||||||
self,
|
self,
|
||||||
segmentation_logits: torch.Tensor,
|
segmentation_logits: torch.Tensor,
|
||||||
patch_offsets: list[tuple[int, int, int]],
|
patch_offsets: list[tuple[int, int, int]],
|
||||||
original_image_sizes: list[tuple[int, int]],
|
target_sizes: list[tuple[int, int]],
|
||||||
size: dict[str, int],
|
size: dict[str, int],
|
||||||
) -> list[torch.Tensor]:
|
) -> list[torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
@@ -358,28 +359,28 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
|
|||||||
segmentation_logits (`torch.Tensor`):
|
segmentation_logits (`torch.Tensor`):
|
||||||
A tensor of shape `(num_patches, num_classes, patch_height, patch_width)` representing predicted logits
|
A tensor of shape `(num_patches, num_classes, patch_height, patch_width)` representing predicted logits
|
||||||
for each image patch.
|
for each image patch.
|
||||||
patch_offsets (`List[Tuple[int, int, int]]`):
|
patch_offsets (`list[tuple[int, int, int]]`):
|
||||||
A list of tuples where each tuple contains:
|
A list of tuples where each tuple contains:
|
||||||
- `image_index` (int): Index of the original image this patch belongs to.
|
- `image_index` (int): Index of the original image this patch belongs to.
|
||||||
- `start` (int): Start pixel index of the patch along the long dimension (height or width).
|
- `start` (int): Start pixel index of the patch along the long dimension (height or width).
|
||||||
- `end` (int): End pixel index of the patch along the long dimension.
|
- `end` (int): End pixel index of the patch along the long dimension.
|
||||||
original_image_sizes (`List[Tuple[int, int]]`):
|
target_sizes (`list[tuple[int, int]]`):
|
||||||
List of original (height, width) dimensions for each image before preprocessing.
|
list of original (height, width) dimensions for each image before preprocessing.
|
||||||
size (`Dict[str, int]`):
|
size (`dict[str, int]`):
|
||||||
A size dict which was used to resize.
|
A size dict which was used to resize.
|
||||||
"""
|
"""
|
||||||
num_classes = segmentation_logits.shape[1]
|
num_classes = segmentation_logits.shape[1]
|
||||||
aggregated_logits = []
|
aggregated_logits = []
|
||||||
patch_counts = []
|
patch_counts = []
|
||||||
|
|
||||||
for image_size in original_image_sizes:
|
for image_size in target_sizes:
|
||||||
height, width = get_size_with_aspect_ratio(image_size, size["shortest_edge"], size["longest_edge"])
|
height, width = get_size_with_aspect_ratio(image_size, size["shortest_edge"], size["longest_edge"])
|
||||||
aggregated_logits.append(torch.zeros((num_classes, height, width), device=segmentation_logits.device))
|
aggregated_logits.append(torch.zeros((num_classes, height, width), device=segmentation_logits.device))
|
||||||
patch_counts.append(torch.zeros((num_classes, height, width), device=segmentation_logits.device))
|
patch_counts.append(torch.zeros((num_classes, height, width), device=segmentation_logits.device))
|
||||||
|
|
||||||
# Stitch patches back into full-sized logit maps
|
# Stitch patches back into full-sized logit maps
|
||||||
for patch_idx, (image_idx, patch_start, patch_end) in enumerate(patch_offsets):
|
for patch_idx, (image_idx, patch_start, patch_end) in enumerate(patch_offsets):
|
||||||
if original_image_sizes[image_idx][0] > original_image_sizes[image_idx][1]:
|
if target_sizes[image_idx][0] > target_sizes[image_idx][1]:
|
||||||
aggregated_logits[image_idx][:, patch_start:patch_end, :] += segmentation_logits[patch_idx]
|
aggregated_logits[image_idx][:, patch_start:patch_end, :] += segmentation_logits[patch_idx]
|
||||||
patch_counts[image_idx][:, patch_start:patch_end, :] += 1
|
patch_counts[image_idx][:, patch_start:patch_end, :] += 1
|
||||||
else:
|
else:
|
||||||
@@ -392,7 +393,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
|
|||||||
averaged_logits = logit_sum / count.clamp(min=1)
|
averaged_logits = logit_sum / count.clamp(min=1)
|
||||||
resized_logits = torch.nn.functional.interpolate(
|
resized_logits = torch.nn.functional.interpolate(
|
||||||
averaged_logits[None, ...],
|
averaged_logits[None, ...],
|
||||||
size=original_image_sizes[idx],
|
size=target_sizes[idx],
|
||||||
mode="bilinear",
|
mode="bilinear",
|
||||||
align_corners=False,
|
align_corners=False,
|
||||||
)[0]
|
)[0]
|
||||||
@@ -404,14 +405,14 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
|
|||||||
def unpad_image(
|
def unpad_image(
|
||||||
self,
|
self,
|
||||||
segmentation_logits: torch.Tensor,
|
segmentation_logits: torch.Tensor,
|
||||||
original_image_sizes: list[tuple[int, int]],
|
target_sizes: list[tuple[int, int]],
|
||||||
size: dict[str, int],
|
size: dict[str, int],
|
||||||
) -> list[torch.Tensor]:
|
) -> list[torch.Tensor]:
|
||||||
"""Restores panoptic segmentation logits to their original image resolutions."""
|
"""Restores panoptic segmentation logits to their original image resolutions."""
|
||||||
|
|
||||||
resized_logits = []
|
resized_logits = []
|
||||||
|
|
||||||
for idx, original_size in enumerate(original_image_sizes):
|
for idx, original_size in enumerate(target_sizes):
|
||||||
target_height, target_width = get_size_with_aspect_ratio(
|
target_height, target_width = get_size_with_aspect_ratio(
|
||||||
original_size, size["shortest_edge"], size["longest_edge"]
|
original_size, size["shortest_edge"], size["longest_edge"]
|
||||||
)
|
)
|
||||||
@@ -425,8 +426,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
|
|||||||
def post_process_semantic_segmentation(
|
def post_process_semantic_segmentation(
|
||||||
self,
|
self,
|
||||||
outputs,
|
outputs,
|
||||||
patch_offsets: list[tuple[int, int, int]],
|
target_sizes: list[tuple[int, int]],
|
||||||
original_image_sizes: list[tuple[int, int]],
|
|
||||||
size: Optional[dict[str, int]] = None,
|
size: Optional[dict[str, int]] = None,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""Post-processes model outputs into final semantic segmentation prediction."""
|
"""Post-processes model outputs into final semantic segmentation prediction."""
|
||||||
@@ -435,6 +435,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
|
|||||||
|
|
||||||
masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width]
|
masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width]
|
||||||
class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1]
|
class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1]
|
||||||
|
patch_offsets = outputs.patch_offsets
|
||||||
|
|
||||||
output_size = get_target_size(size)
|
output_size = get_target_size(size)
|
||||||
masks_queries_logits = torch.nn.functional.interpolate(
|
masks_queries_logits = torch.nn.functional.interpolate(
|
||||||
@@ -449,15 +450,15 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
|
|||||||
|
|
||||||
segmentation_logits = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs)
|
segmentation_logits = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs)
|
||||||
|
|
||||||
output_logits = self.merge_image_patches(segmentation_logits, patch_offsets, original_image_sizes, size)
|
output_logits = self.merge_image_patches(segmentation_logits, patch_offsets, target_sizes, size)
|
||||||
|
|
||||||
preds = torch.stack(output_logits).argmax(dim=1)
|
preds = [logit.argmax(dim=0) for logit in output_logits]
|
||||||
return preds
|
return preds
|
||||||
|
|
||||||
def post_process_panoptic_segmentation(
|
def post_process_panoptic_segmentation(
|
||||||
self,
|
self,
|
||||||
outputs,
|
outputs,
|
||||||
original_image_sizes: list[tuple[int, int]],
|
target_sizes: list[tuple[int, int]],
|
||||||
threshold: float = 0.8,
|
threshold: float = 0.8,
|
||||||
mask_threshold: float = 0.5,
|
mask_threshold: float = 0.5,
|
||||||
overlap_mask_area_threshold: float = 0.8,
|
overlap_mask_area_threshold: float = 0.8,
|
||||||
@@ -481,7 +482,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
|
|||||||
mode="bilinear",
|
mode="bilinear",
|
||||||
)
|
)
|
||||||
|
|
||||||
mask_probs_batch = self.unpad_image(masks_queries_logits, original_image_sizes, size)
|
mask_probs_batch = self.unpad_image(masks_queries_logits, target_sizes, size)
|
||||||
pred_scores_batch, pred_labels_batch = class_queries_logits.softmax(dim=-1).max(-1)
|
pred_scores_batch, pred_labels_batch = class_queries_logits.softmax(dim=-1).max(-1)
|
||||||
|
|
||||||
results: list = []
|
results: list = []
|
||||||
@@ -493,7 +494,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
|
|||||||
|
|
||||||
# No mask found
|
# No mask found
|
||||||
if mask_probs.shape[0] <= 0:
|
if mask_probs.shape[0] <= 0:
|
||||||
height, width = original_image_sizes[i] if original_image_sizes is not None else mask_probs.shape[1:]
|
height, width = target_sizes[i] if target_sizes is not None else mask_probs.shape[1:]
|
||||||
segmentation = torch.zeros((height, width)) - 1
|
segmentation = torch.zeros((height, width)) - 1
|
||||||
results.append({"segmentation": segmentation, "segments_info": []})
|
results.append({"segmentation": segmentation, "segments_info": []})
|
||||||
continue
|
continue
|
||||||
@@ -505,16 +506,17 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
|
|||||||
stuff_classes=stuff_classes,
|
stuff_classes=stuff_classes,
|
||||||
mask_threshold=mask_threshold,
|
mask_threshold=mask_threshold,
|
||||||
overlap_mask_area_threshold=overlap_mask_area_threshold,
|
overlap_mask_area_threshold=overlap_mask_area_threshold,
|
||||||
target_size=original_image_sizes[i] if original_image_sizes is not None else None,
|
target_size=target_sizes[i] if target_sizes is not None else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
results.append({"segmentation": segmentation, "segments_info": segments})
|
results.append({"segmentation": segmentation, "segments_info": segments})
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
@filter_out_non_signature_kwargs()
|
||||||
def post_process_instance_segmentation(
|
def post_process_instance_segmentation(
|
||||||
self,
|
self,
|
||||||
outputs,
|
outputs,
|
||||||
original_image_sizes: list[tuple[int, int]],
|
target_sizes: list[tuple[int, int]],
|
||||||
threshold: float = 0.8,
|
threshold: float = 0.8,
|
||||||
size: Optional[dict[str, int]] = None,
|
size: Optional[dict[str, int]] = None,
|
||||||
):
|
):
|
||||||
@@ -532,7 +534,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
|
|||||||
mode="bilinear",
|
mode="bilinear",
|
||||||
)
|
)
|
||||||
|
|
||||||
mask_probs_batch = self.unpad_image(masks_queries_logits, original_image_sizes, size)
|
mask_probs_batch = self.unpad_image(masks_queries_logits, target_sizes, size)
|
||||||
|
|
||||||
device = masks_queries_logits.device
|
device = masks_queries_logits.device
|
||||||
batch_size = class_queries_logits.shape[0]
|
batch_size = class_queries_logits.shape[0]
|
||||||
@@ -554,7 +556,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
|
|||||||
)
|
)
|
||||||
pred_scores = scores * mask_scores
|
pred_scores = scores * mask_scores
|
||||||
|
|
||||||
segmentation = torch.zeros(original_image_sizes[i], device=device) - 1
|
segmentation = torch.zeros(target_sizes[i], device=device) - 1
|
||||||
|
|
||||||
instance_maps, segments = [], []
|
instance_maps, segments = [], []
|
||||||
current_segment_id = 0
|
current_segment_id = 0
|
||||||
|
|||||||
@@ -74,6 +74,8 @@ class EomtForUniversalSegmentationOutput(ModelOutput):
|
|||||||
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
||||||
Tuple of `tuple(torch.FloatTensor)` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
Tuple of `tuple(torch.FloatTensor)` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
||||||
sequence_length)`. Self and Cross Attentions weights from transformer decoder.
|
sequence_length)`. Self and Cross Attentions weights from transformer decoder.
|
||||||
|
patch_offsets (`list[torch.Tensor]`, *optional*):
|
||||||
|
list of tuples indicating the image index and start and end positions of patches for semantic segementation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
loss: Optional[torch.FloatTensor] = None
|
loss: Optional[torch.FloatTensor] = None
|
||||||
@@ -82,6 +84,7 @@ class EomtForUniversalSegmentationOutput(ModelOutput):
|
|||||||
last_hidden_state: Optional[torch.FloatTensor] = None
|
last_hidden_state: Optional[torch.FloatTensor] = None
|
||||||
hidden_states: Optional[tuple[torch.FloatTensor]] = None
|
hidden_states: Optional[tuple[torch.FloatTensor]] = None
|
||||||
attentions: Optional[tuple[torch.FloatTensor]] = None
|
attentions: Optional[tuple[torch.FloatTensor]] = None
|
||||||
|
patch_offsets: Optional[list[torch.Tensor]] = None
|
||||||
|
|
||||||
|
|
||||||
# Adapted from https://github.com/facebookresearch/detectron2/blob/main/projects/PointRend/point_rend/point_features.py
|
# Adapted from https://github.com/facebookresearch/detectron2/blob/main/projects/PointRend/point_rend/point_features.py
|
||||||
@@ -996,7 +999,7 @@ class EomtPreTrainedModel(PreTrainedModel):
|
|||||||
base_model_prefix = "eomt"
|
base_model_prefix = "eomt"
|
||||||
main_input_name = "pixel_values"
|
main_input_name = "pixel_values"
|
||||||
supports_gradient_checkpointing = False
|
supports_gradient_checkpointing = False
|
||||||
_no_split_modules = ["EomtMLP"]
|
_no_split_modules = ["EomtLayer"]
|
||||||
_supports_sdpa = True
|
_supports_sdpa = True
|
||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn_2 = True
|
||||||
|
|
||||||
@@ -1097,13 +1100,16 @@ class EomtForUniversalSegmentation(EomtPreTrainedModel):
|
|||||||
class_labels: Optional[list[Tensor]] = None,
|
class_labels: Optional[list[Tensor]] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
|
patch_offsets: Optional[list[Tensor]] = None,
|
||||||
) -> EomtForUniversalSegmentationOutput:
|
) -> EomtForUniversalSegmentationOutput:
|
||||||
r"""
|
r"""
|
||||||
mask_labels (`List[torch.Tensor]`, *optional*):
|
mask_labels (`list[torch.Tensor]`, *optional*):
|
||||||
List of mask labels of shape `(num_labels, height, width)` to be fed to a model
|
list of mask labels of shape `(num_labels, height, width)` to be fed to a model
|
||||||
class_labels (`List[torch.LongTensor]`, *optional*):
|
class_labels (`list[torch.LongTensor]`, *optional*):
|
||||||
list of target class labels of shape `(num_labels, height, width)` to be fed to a model. They identify the
|
list of target class labels of shape `(num_labels, height, width)` to be fed to a model. They identify the
|
||||||
labels of `mask_labels`, e.g. the label of `mask_labels[i][j]` if `class_labels[i][j]`.
|
labels of `mask_labels`, e.g. the label of `mask_labels[i][j]` if `class_labels[i][j]`.
|
||||||
|
patch_offsets (`list[torch.Tensor]`, *optional*):
|
||||||
|
list of tuples indicating the image index and start and end positions of patches for semantic segementation.
|
||||||
"""
|
"""
|
||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
@@ -1126,7 +1132,7 @@ class EomtForUniversalSegmentation(EomtPreTrainedModel):
|
|||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
if idx == self.num_hidden_layers - self.config.num_blocks:
|
if idx == self.num_hidden_layers - self.config.num_blocks:
|
||||||
query = self.query.weight[None, :, :].expand(hidden_states.shape[0], -1, -1)
|
query = self.query.weight[None, :, :].expand(hidden_states.shape[0], -1, -1).to(hidden_states.device)
|
||||||
hidden_states = torch.cat((query, hidden_states), dim=1)
|
hidden_states = torch.cat((query, hidden_states), dim=1)
|
||||||
|
|
||||||
if idx >= self.num_hidden_layers - self.config.num_blocks and (
|
if idx >= self.num_hidden_layers - self.config.num_blocks and (
|
||||||
@@ -1206,6 +1212,7 @@ class EomtForUniversalSegmentation(EomtPreTrainedModel):
|
|||||||
last_hidden_state=sequence_output,
|
last_hidden_state=sequence_output,
|
||||||
hidden_states=all_hidden_states,
|
hidden_states=all_hidden_states,
|
||||||
attentions=all_attentions,
|
attentions=all_attentions,
|
||||||
|
patch_offsets=patch_offsets,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_input_embeddings(self):
|
def get_input_embeddings(self):
|
||||||
|
|||||||
@@ -226,6 +226,8 @@ class EomtForUniversalSegmentationOutput(ModelOutput):
|
|||||||
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
||||||
Tuple of `tuple(torch.FloatTensor)` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
Tuple of `tuple(torch.FloatTensor)` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
||||||
sequence_length)`. Self and Cross Attentions weights from transformer decoder.
|
sequence_length)`. Self and Cross Attentions weights from transformer decoder.
|
||||||
|
patch_offsets (`list[torch.Tensor]`, *optional*):
|
||||||
|
list of tuples indicating the image index and start and end positions of patches for semantic segementation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
loss: Optional[torch.FloatTensor] = None
|
loss: Optional[torch.FloatTensor] = None
|
||||||
@@ -234,6 +236,7 @@ class EomtForUniversalSegmentationOutput(ModelOutput):
|
|||||||
last_hidden_state: Optional[torch.FloatTensor] = None
|
last_hidden_state: Optional[torch.FloatTensor] = None
|
||||||
hidden_states: Optional[tuple[torch.FloatTensor]] = None
|
hidden_states: Optional[tuple[torch.FloatTensor]] = None
|
||||||
attentions: Optional[tuple[torch.FloatTensor]] = None
|
attentions: Optional[tuple[torch.FloatTensor]] = None
|
||||||
|
patch_offsets: Optional[list[torch.Tensor]] = None
|
||||||
|
|
||||||
|
|
||||||
class EomtLoss(Mask2FormerLoss):
|
class EomtLoss(Mask2FormerLoss):
|
||||||
@@ -368,7 +371,7 @@ class EomtPreTrainedModel(PreTrainedModel):
|
|||||||
base_model_prefix = "eomt"
|
base_model_prefix = "eomt"
|
||||||
main_input_name = "pixel_values"
|
main_input_name = "pixel_values"
|
||||||
supports_gradient_checkpointing = False
|
supports_gradient_checkpointing = False
|
||||||
_no_split_modules = ["EomtMLP"]
|
_no_split_modules = ["EomtLayer"]
|
||||||
_supports_sdpa = True
|
_supports_sdpa = True
|
||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn_2 = True
|
||||||
|
|
||||||
@@ -473,13 +476,16 @@ class EomtForUniversalSegmentation(Mask2FormerForUniversalSegmentation, nn.Modul
|
|||||||
class_labels: Optional[list[Tensor]] = None,
|
class_labels: Optional[list[Tensor]] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
|
patch_offsets: Optional[list[Tensor]] = None,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
mask_labels (`List[torch.Tensor]`, *optional*):
|
mask_labels (`list[torch.Tensor]`, *optional*):
|
||||||
List of mask labels of shape `(num_labels, height, width)` to be fed to a model
|
list of mask labels of shape `(num_labels, height, width)` to be fed to a model
|
||||||
class_labels (`List[torch.LongTensor]`, *optional*):
|
class_labels (`list[torch.LongTensor]`, *optional*):
|
||||||
list of target class labels of shape `(num_labels, height, width)` to be fed to a model. They identify the
|
list of target class labels of shape `(num_labels, height, width)` to be fed to a model. They identify the
|
||||||
labels of `mask_labels`, e.g. the label of `mask_labels[i][j]` if `class_labels[i][j]`.
|
labels of `mask_labels`, e.g. the label of `mask_labels[i][j]` if `class_labels[i][j]`.
|
||||||
|
patch_offsets (`list[torch.Tensor]`, *optional*):
|
||||||
|
list of tuples indicating the image index and start and end positions of patches for semantic segementation.
|
||||||
"""
|
"""
|
||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
@@ -502,7 +508,7 @@ class EomtForUniversalSegmentation(Mask2FormerForUniversalSegmentation, nn.Modul
|
|||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
if idx == self.num_hidden_layers - self.config.num_blocks:
|
if idx == self.num_hidden_layers - self.config.num_blocks:
|
||||||
query = self.query.weight[None, :, :].expand(hidden_states.shape[0], -1, -1)
|
query = self.query.weight[None, :, :].expand(hidden_states.shape[0], -1, -1).to(hidden_states.device)
|
||||||
hidden_states = torch.cat((query, hidden_states), dim=1)
|
hidden_states = torch.cat((query, hidden_states), dim=1)
|
||||||
|
|
||||||
if idx >= self.num_hidden_layers - self.config.num_blocks and (
|
if idx >= self.num_hidden_layers - self.config.num_blocks and (
|
||||||
@@ -582,6 +588,7 @@ class EomtForUniversalSegmentation(Mask2FormerForUniversalSegmentation, nn.Modul
|
|||||||
last_hidden_state=sequence_output,
|
last_hidden_state=sequence_output,
|
||||||
hidden_states=all_hidden_states,
|
hidden_states=all_hidden_states,
|
||||||
attentions=all_attentions,
|
attentions=all_attentions,
|
||||||
|
patch_offsets=patch_offsets,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -84,10 +84,11 @@ class EomtImageProcessingTester:
|
|||||||
"num_labels": self.num_labels,
|
"num_labels": self.num_labels,
|
||||||
}
|
}
|
||||||
|
|
||||||
def prepare_fake_eomt_outputs(self, batch_size):
|
def prepare_fake_eomt_outputs(self, batch_size, patch_offsets=None):
|
||||||
return EomtForUniversalSegmentationOutput(
|
return EomtForUniversalSegmentationOutput(
|
||||||
masks_queries_logits=torch.randn((batch_size, self.num_queries, self.height, self.width)),
|
masks_queries_logits=torch.randn((batch_size, self.num_queries, self.height, self.width)),
|
||||||
class_queries_logits=torch.randn((batch_size, self.num_queries, self.num_classes + 1)),
|
class_queries_logits=torch.randn((batch_size, self.num_queries, self.num_classes + 1)),
|
||||||
|
patch_offsets=patch_offsets,
|
||||||
)
|
)
|
||||||
|
|
||||||
def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
|
def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
|
||||||
@@ -263,13 +264,13 @@ class EomtImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
|||||||
image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
|
image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
|
||||||
|
|
||||||
inputs = processor(images=image, do_split_image=True, return_tensors="pt")
|
inputs = processor(images=image, do_split_image=True, return_tensors="pt")
|
||||||
patch_offsets = inputs.pop("patch_offsets")
|
patch_offsets = inputs["patch_offsets"]
|
||||||
|
|
||||||
original_sizes = [image.size[::-1]]
|
target_sizes = [image.size[::-1]]
|
||||||
|
|
||||||
# For semantic segmentation, the BS of output is 2 coz, two patches are created for the image.
|
# For semantic segmentation, the BS of output is 2 coz, two patches are created for the image.
|
||||||
outputs = self.image_processor_tester.prepare_fake_eomt_outputs(inputs["pixel_values"].shape[0])
|
outputs = self.image_processor_tester.prepare_fake_eomt_outputs(inputs["pixel_values"].shape[0], patch_offsets)
|
||||||
segmentation = processor.post_process_semantic_segmentation(outputs, patch_offsets, original_sizes)
|
segmentation = processor.post_process_semantic_segmentation(outputs, target_sizes)
|
||||||
|
|
||||||
self.assertEqual(segmentation[0].shape, (image.height, image.width))
|
self.assertEqual(segmentation[0].shape, (image.height, image.width))
|
||||||
|
|
||||||
|
|||||||
@@ -17,12 +17,13 @@ import unittest
|
|||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from transformers import AutoImageProcessor, EomtConfig, EomtForUniversalSegmentation
|
from transformers import AutoImageProcessor, EomtConfig, EomtForUniversalSegmentation, pipeline
|
||||||
from transformers.testing_utils import require_torch, require_torch_accelerator, require_torch_fp16, slow, torch_device
|
from transformers.testing_utils import require_torch, require_torch_accelerator, require_torch_fp16, slow, torch_device
|
||||||
from transformers.utils import is_torch_available, is_vision_available
|
from transformers.utils import is_torch_available, is_vision_available
|
||||||
|
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor
|
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor
|
||||||
|
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -100,8 +101,9 @@ class EomtForUniversalSegmentationTester:
|
|||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class EomtForUniversalSegmentationTest(ModelTesterMixin, unittest.TestCase):
|
class EomtForUniversalSegmentationTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (EomtForUniversalSegmentation,) if is_torch_available() else ()
|
all_model_classes = (EomtForUniversalSegmentation,) if is_torch_available() else ()
|
||||||
|
pipeline_model_mapping = {"image-segmentation": EomtForUniversalSegmentation} if is_torch_available() else {}
|
||||||
is_encoder_decoder = False
|
is_encoder_decoder = False
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
@@ -340,7 +342,6 @@ class EomtForUniversalSegmentationIntegrationTest(unittest.TestCase):
|
|||||||
image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
|
image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
|
||||||
|
|
||||||
inputs = processor(images=image, return_tensors="pt").to(model.device)
|
inputs = processor(images=image, return_tensors="pt").to(model.device)
|
||||||
patch_offsets = inputs.pop("patch_offsets", None)
|
|
||||||
|
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
@@ -348,11 +349,9 @@ class EomtForUniversalSegmentationIntegrationTest(unittest.TestCase):
|
|||||||
self.assertTrue(outputs.class_queries_logits.shape == (2, 100, 151))
|
self.assertTrue(outputs.class_queries_logits.shape == (2, 100, 151))
|
||||||
self.assertTrue(outputs.masks_queries_logits.shape == (2, 100, 128, 128))
|
self.assertTrue(outputs.masks_queries_logits.shape == (2, 100, 128, 128))
|
||||||
|
|
||||||
preds = processor.post_process_semantic_segmentation(
|
preds = processor.post_process_semantic_segmentation(outputs, target_sizes=[(image.size[1], image.size[0])])[0]
|
||||||
outputs, original_image_sizes=[(image.size[1], image.size[0])], patch_offsets=patch_offsets
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertTrue(preds.shape[1:] == (image.size[1], image.size[0]))
|
self.assertTrue(preds.shape == (image.size[1], image.size[0]))
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
EXPECTED_SLICE = torch.tensor([
|
EXPECTED_SLICE = torch.tensor([
|
||||||
@@ -369,7 +368,7 @@ class EomtForUniversalSegmentationIntegrationTest(unittest.TestCase):
|
|||||||
], device=model.device)
|
], device=model.device)
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
output_slice = preds[0, :10, :10]
|
output_slice = preds[:10, :10]
|
||||||
torch.testing.assert_close(output_slice, EXPECTED_SLICE, rtol=1e-2, atol=1e-2)
|
torch.testing.assert_close(output_slice, EXPECTED_SLICE, rtol=1e-2, atol=1e-2)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@@ -387,9 +386,7 @@ class EomtForUniversalSegmentationIntegrationTest(unittest.TestCase):
|
|||||||
self.assertTrue(outputs.class_queries_logits.shape == (1, 200, 134))
|
self.assertTrue(outputs.class_queries_logits.shape == (1, 200, 134))
|
||||||
self.assertTrue(outputs.masks_queries_logits.shape == (1, 200, 160, 160))
|
self.assertTrue(outputs.masks_queries_logits.shape == (1, 200, 160, 160))
|
||||||
|
|
||||||
preds = processor.post_process_panoptic_segmentation(
|
preds = processor.post_process_panoptic_segmentation(outputs, target_sizes=[(image.size[1], image.size[0])])[0]
|
||||||
outputs, original_image_sizes=[(image.size[1], image.size[0])]
|
|
||||||
)[0]
|
|
||||||
segmentation, segments_info = preds["segmentation"], preds["segments_info"]
|
segmentation, segments_info = preds["segmentation"], preds["segments_info"]
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
@@ -438,9 +435,7 @@ class EomtForUniversalSegmentationIntegrationTest(unittest.TestCase):
|
|||||||
self.assertTrue(outputs.class_queries_logits.shape == (1, 200, 81))
|
self.assertTrue(outputs.class_queries_logits.shape == (1, 200, 81))
|
||||||
self.assertTrue(outputs.masks_queries_logits.shape == (1, 200, 160, 160))
|
self.assertTrue(outputs.masks_queries_logits.shape == (1, 200, 160, 160))
|
||||||
|
|
||||||
preds = processor.post_process_instance_segmentation(
|
preds = processor.post_process_instance_segmentation(outputs, target_sizes=[(image.size[1], image.size[0])])[0]
|
||||||
outputs, original_image_sizes=[(image.size[1], image.size[0])]
|
|
||||||
)[0]
|
|
||||||
segmentation, segments_info = preds["segmentation"], preds["segments_info"]
|
segmentation, segments_info = preds["segmentation"], preds["segments_info"]
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
@@ -473,3 +468,15 @@ class EomtForUniversalSegmentationIntegrationTest(unittest.TestCase):
|
|||||||
self.assertEqual(actual["id"], expected["id"])
|
self.assertEqual(actual["id"], expected["id"])
|
||||||
self.assertEqual(actual["label_id"], expected["label_id"])
|
self.assertEqual(actual["label_id"], expected["label_id"])
|
||||||
self.assertAlmostEqual(actual["score"], expected["score"], delta=1e-3)
|
self.assertAlmostEqual(actual["score"], expected["score"], delta=1e-3)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_segmentation_pipeline(self):
|
||||||
|
image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
|
||||||
|
|
||||||
|
pipe = pipeline(model=self.model_id, subtask="panoptic", device=torch_device)
|
||||||
|
output = pipe(image)
|
||||||
|
|
||||||
|
EXPECTED_OUTPUT_LABELS = ["cat", "cat", "couch", "remote", "remote"]
|
||||||
|
|
||||||
|
output_labels = [segment["label"] for segment in output]
|
||||||
|
self.assertEqual(output_labels, EXPECTED_OUTPUT_LABELS)
|
||||||
|
|||||||
Reference in New Issue
Block a user