From b61023a1b760b207d99b699dafc1fbfde992c12c Mon Sep 17 00:00:00 2001 From: Yaswanth Gali <82788246+yaswanth19@users.noreply.github.com> Date: Wed, 2 Jul 2025 16:55:26 +0530 Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=A8=F0=9F=9A=A8=F0=9F=9A=A8=20[eomt]?= =?UTF-8?q?=20make=20EoMT=20compatible=20with=20pipeline=20(#39122)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Make EoMT compatible with pipeline * Implicit patch offsets * remove patch offsets from arg * Modify tests * Update example * fix proc testcase * Add few more args * add pipeline test suite * fix * docstring fixes * add pipeline test * changes w.r.t review * 🙈 MB * should fix device mismatch * debug * Fixes device mismatch * use decorator * we can split mlp * expected values update --------- Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> --- docs/source/en/model_doc/eomt.md | 16 ++--- .../models/eomt/image_processing_eomt.py | 61 ++++++++++--------- .../models/eomt/image_processing_eomt_fast.py | 48 ++++++++------- src/transformers/models/eomt/modeling_eomt.py | 17 ++++-- src/transformers/models/eomt/modular_eomt.py | 17 ++++-- .../models/eomt/test_image_processing_eomt.py | 11 ++-- tests/models/eomt/test_modeling_eomt.py | 35 ++++++----- 7 files changed, 113 insertions(+), 92 deletions(-) diff --git a/docs/source/en/model_doc/eomt.md b/docs/source/en/model_doc/eomt.md index 34842de210..86816a475f 100644 --- a/docs/source/en/model_doc/eomt.md +++ b/docs/source/en/model_doc/eomt.md @@ -74,20 +74,16 @@ inputs = processor( return_tensors="pt", ) -# Remove Patch Offsets from inputs — only used later for post-processing. -patch_offsets = inputs.pop("patch_offsets") - with torch.inference_mode(): outputs = model(**inputs) # Prepare the original image size in the format (height, width) -original_image_sizes = [(image.height, image.width)] +target_sizes = [(image.height, image.width)] # Post-process the model outputs to get final segmentation prediction preds = processor.post_process_semantic_segmentation( outputs, - patch_offsets=patch_offsets, - original_image_sizes=original_image_sizes, + target_sizes=target_sizes, ) # Visualize the segmentation mask @@ -130,12 +126,12 @@ with torch.inference_mode(): outputs = model(**inputs) # Prepare the original image size in the format (height, width) -original_image_sizes = [(image.height, image.width)] +target_sizes = [(image.height, image.width)] # Post-process the model outputs to get final segmentation prediction preds = processor.post_process_instance_segmentation( outputs, - original_image_sizes=original_image_sizes, + target_sizes=target_sizes, ) # Visualize the segmentation mask @@ -173,12 +169,12 @@ with torch.inference_mode(): outputs = model(**inputs) # Prepare the original image size in the format (height, width) -original_image_sizes = [(image.height, image.width)] +target_sizes = [(image.height, image.width)] # Post-process the model outputs to get final segmentation prediction preds = processor.post_process_panoptic_segmentation( outputs, - original_image_sizes=original_image_sizes, + target_sizes=target_sizes, ) # Visualize the panoptic segmentation mask diff --git a/src/transformers/models/eomt/image_processing_eomt.py b/src/transformers/models/eomt/image_processing_eomt.py index 73fe46034c..e63a1be95f 100644 --- a/src/transformers/models/eomt/image_processing_eomt.py +++ b/src/transformers/models/eomt/image_processing_eomt.py @@ -97,7 +97,7 @@ def get_size_with_aspect_ratio(image_size, size, max_size=None) -> tuple[int, in Computes the output image size given the input image size and the desired output size. Args: - image_size (`Tuple[int, int]`): + image_size (`tuple[int, int]`): The input image size. size (`int`): The desired output size. @@ -531,13 +531,13 @@ class EomtImageProcessor(BaseImageProcessor): Image or batch of images to preprocess. segmentation_maps (`ImageInput`, *optional*): The corresponding semantic segmentation maps with the pixel-wise annotations. - instance_id_to_semantic_id (`List[Dict[int, int]]` or `Dict[int, int]`, *optional*): + instance_id_to_semantic_id (`list[dict[int, int]]` or `dict[int, int]`, *optional*): A mapping between object instance ids and class ids. do_split_image (`bool`, *optional*, defaults to `self.do_split_image`): Whether to split the input images into overlapping patches for semantic segmentation. do_resize (`bool`, *optional*, defaults to `self.do_resize`): Whether to resize the input images. - size (`Dict[str, int]`, *optional*, defaults to `self.size`): + size (`dict[str, int]`, *optional*, defaults to `self.size`): Target size as a dictionary with `"shortest_edge"` and `"longest_edge"` keys. resample (`PILImageResampling`, *optional*, defaults to `self.resample`): Resampling filter to use when resizing. @@ -550,9 +550,9 @@ class EomtImageProcessor(BaseImageProcessor): do_pad (`bool`, *optional*, defaults to `False`): Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest number of patches in the batch. Padding will be applied to the bottom and right with zeros. - image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`): Mean for normalization. Single value or list for each channel. - image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`): Standard deviation for normalization. Single value or list for each channel. ignore_index (`int`, *optional*): Label to be assigned to background pixels in segmentation maps. If provided, segmentation map pixels @@ -640,7 +640,7 @@ class EomtImageProcessor(BaseImageProcessor): ) if do_split_image and patch_offsets: - encoded_inputs["patch_offsets"] = patch_offsets + encoded_inputs["patch_offsets"] = [torch.tensor(offsets) for offsets in patch_offsets] return encoded_inputs @@ -663,8 +663,8 @@ class EomtImageProcessor(BaseImageProcessor): each mask. Args: - pixel_values_list (`List[ImageInput]`): - List of images (pixel values) to be padded. Each image should be a tensor of shape `(channels, height, + pixel_values_list (`list[ImageInput]`): + list of images (pixel values) to be padded. Each image should be a tensor of shape `(channels, height, width)`. segmentation_maps (`ImageInput`, *optional*): @@ -678,7 +678,7 @@ class EomtImageProcessor(BaseImageProcessor): - 1 for pixels that are real (i.e. **not masked**), - 0 for pixels that are padding (i.e. **masked**). - instance_id_to_semantic_id (`List[Dict[int, int]]` or `Dict[int, int]`, *optional*): + instance_id_to_semantic_id (`list[dict[int, int]]` or `dict[int, int]`, *optional*): A mapping between object instance ids and class ids. If passed, `segmentation_maps` is treated as an instance segmentation map where each pixel represents an instance id. Can be provided as a single dictionary with a global/dataset-level mapping or as a list of dictionaries (one per image), to map @@ -740,7 +740,7 @@ class EomtImageProcessor(BaseImageProcessor): self, segmentation_logits: torch.Tensor, patch_offsets: list[tuple[int, int, int]], - original_image_sizes: list[tuple[int, int]], + target_sizes: list[tuple[int, int]], size: dict[str, int], ) -> list[torch.Tensor]: """ @@ -750,28 +750,28 @@ class EomtImageProcessor(BaseImageProcessor): segmentation_logits (`torch.Tensor`): A tensor of shape `(num_patches, num_classes, patch_height, patch_width)` representing predicted logits for each image patch. - patch_offsets (`List[Tuple[int, int, int]]`): + patch_offsets (`list[tuple[int, int, int]]`): A list of tuples where each tuple contains: - `image_index` (int): Index of the original image this patch belongs to. - `start` (int): Start pixel index of the patch along the long dimension (height or width). - `end` (int): End pixel index of the patch along the long dimension. - original_image_sizes (`List[Tuple[int, int]]`): - List of original (height, width) dimensions for each image before preprocessing. - size (`Dict[str, int]`): + target_sizes (`list[tuple[int, int]]`): + list of original (height, width) dimensions for each image before preprocessing. + size (`dict[str, int]`): A size dict which was used to resize. """ num_classes = segmentation_logits.shape[1] aggregated_logits = [] patch_counts = [] - for image_size in original_image_sizes: + for image_size in target_sizes: height, width = get_size_with_aspect_ratio(image_size, size["shortest_edge"], size["longest_edge"]) aggregated_logits.append(torch.zeros((num_classes, height, width), device=segmentation_logits.device)) patch_counts.append(torch.zeros((num_classes, height, width), device=segmentation_logits.device)) # Stitch patches back into full-sized logit maps for patch_idx, (image_idx, patch_start, patch_end) in enumerate(patch_offsets): - if original_image_sizes[image_idx][0] > original_image_sizes[image_idx][1]: + if target_sizes[image_idx][0] > target_sizes[image_idx][1]: aggregated_logits[image_idx][:, patch_start:patch_end, :] += segmentation_logits[patch_idx] patch_counts[image_idx][:, patch_start:patch_end, :] += 1 else: @@ -784,7 +784,7 @@ class EomtImageProcessor(BaseImageProcessor): averaged_logits = logit_sum / count.clamp(min=1) resized_logits = F.interpolate( averaged_logits[None, ...], - size=original_image_sizes[idx], + size=target_sizes[idx], mode="bilinear", align_corners=False, )[0] @@ -796,14 +796,14 @@ class EomtImageProcessor(BaseImageProcessor): def unpad_image( self, segmentation_logits: torch.Tensor, - original_image_sizes: list[tuple[int, int]], + target_sizes: list[tuple[int, int]], size: dict[str, int], ) -> list[torch.Tensor]: """Restores panoptic segmentation logits to their original image resolutions.""" resized_logits = [] - for idx, original_size in enumerate(original_image_sizes): + for idx, original_size in enumerate(target_sizes): target_height, target_width = get_size_with_aspect_ratio( original_size, size["shortest_edge"], size["longest_edge"] ) @@ -817,8 +817,7 @@ class EomtImageProcessor(BaseImageProcessor): def post_process_semantic_segmentation( self, outputs, - patch_offsets: list[tuple[int, int, int]], - original_image_sizes: list[tuple[int, int]], + target_sizes: list[tuple[int, int]], size: Optional[dict[str, int]] = None, ) -> np.ndarray: """Post-processes model outputs into final semantic segmentation prediction.""" @@ -827,6 +826,7 @@ class EomtImageProcessor(BaseImageProcessor): masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width] class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1] + patch_offsets = outputs.patch_offsets output_size = get_target_size(size) masks_queries_logits = F.interpolate( @@ -841,15 +841,15 @@ class EomtImageProcessor(BaseImageProcessor): segmentation_logits = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs) - output_logits = self.merge_image_patches(segmentation_logits, patch_offsets, original_image_sizes, size) + output_logits = self.merge_image_patches(segmentation_logits, patch_offsets, target_sizes, size) - preds = torch.stack(output_logits).argmax(dim=1) + preds = [logit.argmax(dim=0) for logit in output_logits] return preds def post_process_panoptic_segmentation( self, outputs, - original_image_sizes: list[tuple[int, int]], + target_sizes: list[tuple[int, int]], threshold: float = 0.8, mask_threshold: float = 0.5, overlap_mask_area_threshold: float = 0.8, @@ -873,7 +873,7 @@ class EomtImageProcessor(BaseImageProcessor): mode="bilinear", ) - mask_probs_batch = self.unpad_image(masks_queries_logits, original_image_sizes, size) + mask_probs_batch = self.unpad_image(masks_queries_logits, target_sizes, size) pred_scores_batch, pred_labels_batch = class_queries_logits.softmax(dim=-1).max(-1) results: list = [] @@ -885,7 +885,7 @@ class EomtImageProcessor(BaseImageProcessor): # No mask found if mask_probs.shape[0] <= 0: - height, width = original_image_sizes[i] if original_image_sizes is not None else mask_probs.shape[1:] + height, width = target_sizes[i] if target_sizes is not None else mask_probs.shape[1:] segmentation = torch.zeros((height, width)) - 1 results.append({"segmentation": segmentation, "segments_info": []}) continue @@ -897,16 +897,17 @@ class EomtImageProcessor(BaseImageProcessor): stuff_classes=stuff_classes, mask_threshold=mask_threshold, overlap_mask_area_threshold=overlap_mask_area_threshold, - target_size=original_image_sizes[i] if original_image_sizes is not None else None, + target_size=target_sizes[i] if target_sizes is not None else None, ) results.append({"segmentation": segmentation, "segments_info": segments}) return results + @filter_out_non_signature_kwargs() def post_process_instance_segmentation( self, outputs, - original_image_sizes: list[tuple[int, int]], + target_sizes: list[tuple[int, int]], threshold: float = 0.5, size: Optional[dict[str, int]] = None, ): @@ -924,7 +925,7 @@ class EomtImageProcessor(BaseImageProcessor): mode="bilinear", ) - mask_probs_batch = self.unpad_image(masks_queries_logits, original_image_sizes, size) + mask_probs_batch = self.unpad_image(masks_queries_logits, target_sizes, size) device = masks_queries_logits.device batch_size = class_queries_logits.shape[0] @@ -946,7 +947,7 @@ class EomtImageProcessor(BaseImageProcessor): ) pred_scores = scores * mask_scores - segmentation = torch.zeros(original_image_sizes[i], device=device) - 1 + segmentation = torch.zeros(target_sizes[i], device=device) - 1 instance_maps, segments = [], [] current_segment_id = 0 diff --git a/src/transformers/models/eomt/image_processing_eomt_fast.py b/src/transformers/models/eomt/image_processing_eomt_fast.py index 04b53c418d..343c6ae2cf 100644 --- a/src/transformers/models/eomt/image_processing_eomt_fast.py +++ b/src/transformers/models/eomt/image_processing_eomt_fast.py @@ -41,6 +41,7 @@ from ...processing_utils import Unpack from ...utils import ( TensorType, auto_docstring, + filter_out_non_signature_kwargs, is_torch_available, is_torchvision_available, is_torchvision_v2_available, @@ -268,7 +269,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast): r""" segmentation_maps (`ImageInput`, *optional*): The segmentation maps to preprocess for corresponding images. - instance_id_to_semantic_id (`List[Dict[int, int]]` or `Dict[int, int]`, *optional*): + instance_id_to_semantic_id (`list[dict[int, int]]` or `dict[int, int]`, *optional*): A mapping between object instance ids and class ids. """ # args are not validated, but their order in the `preprocess` and `_preprocess` signatures must be the same @@ -340,7 +341,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast): outputs["class_labels"] = class_labels if patch_offsets: - outputs["patch_offsets"] = patch_offsets + outputs["patch_offsets"] = [torch.tensor(offsets) for offsets in patch_offsets] return outputs @@ -348,7 +349,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast): self, segmentation_logits: torch.Tensor, patch_offsets: list[tuple[int, int, int]], - original_image_sizes: list[tuple[int, int]], + target_sizes: list[tuple[int, int]], size: dict[str, int], ) -> list[torch.Tensor]: """ @@ -358,28 +359,28 @@ class EomtImageProcessorFast(BaseImageProcessorFast): segmentation_logits (`torch.Tensor`): A tensor of shape `(num_patches, num_classes, patch_height, patch_width)` representing predicted logits for each image patch. - patch_offsets (`List[Tuple[int, int, int]]`): + patch_offsets (`list[tuple[int, int, int]]`): A list of tuples where each tuple contains: - `image_index` (int): Index of the original image this patch belongs to. - `start` (int): Start pixel index of the patch along the long dimension (height or width). - `end` (int): End pixel index of the patch along the long dimension. - original_image_sizes (`List[Tuple[int, int]]`): - List of original (height, width) dimensions for each image before preprocessing. - size (`Dict[str, int]`): + target_sizes (`list[tuple[int, int]]`): + list of original (height, width) dimensions for each image before preprocessing. + size (`dict[str, int]`): A size dict which was used to resize. """ num_classes = segmentation_logits.shape[1] aggregated_logits = [] patch_counts = [] - for image_size in original_image_sizes: + for image_size in target_sizes: height, width = get_size_with_aspect_ratio(image_size, size["shortest_edge"], size["longest_edge"]) aggregated_logits.append(torch.zeros((num_classes, height, width), device=segmentation_logits.device)) patch_counts.append(torch.zeros((num_classes, height, width), device=segmentation_logits.device)) # Stitch patches back into full-sized logit maps for patch_idx, (image_idx, patch_start, patch_end) in enumerate(patch_offsets): - if original_image_sizes[image_idx][0] > original_image_sizes[image_idx][1]: + if target_sizes[image_idx][0] > target_sizes[image_idx][1]: aggregated_logits[image_idx][:, patch_start:patch_end, :] += segmentation_logits[patch_idx] patch_counts[image_idx][:, patch_start:patch_end, :] += 1 else: @@ -392,7 +393,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast): averaged_logits = logit_sum / count.clamp(min=1) resized_logits = torch.nn.functional.interpolate( averaged_logits[None, ...], - size=original_image_sizes[idx], + size=target_sizes[idx], mode="bilinear", align_corners=False, )[0] @@ -404,14 +405,14 @@ class EomtImageProcessorFast(BaseImageProcessorFast): def unpad_image( self, segmentation_logits: torch.Tensor, - original_image_sizes: list[tuple[int, int]], + target_sizes: list[tuple[int, int]], size: dict[str, int], ) -> list[torch.Tensor]: """Restores panoptic segmentation logits to their original image resolutions.""" resized_logits = [] - for idx, original_size in enumerate(original_image_sizes): + for idx, original_size in enumerate(target_sizes): target_height, target_width = get_size_with_aspect_ratio( original_size, size["shortest_edge"], size["longest_edge"] ) @@ -425,8 +426,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast): def post_process_semantic_segmentation( self, outputs, - patch_offsets: list[tuple[int, int, int]], - original_image_sizes: list[tuple[int, int]], + target_sizes: list[tuple[int, int]], size: Optional[dict[str, int]] = None, ) -> np.ndarray: """Post-processes model outputs into final semantic segmentation prediction.""" @@ -435,6 +435,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast): masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width] class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1] + patch_offsets = outputs.patch_offsets output_size = get_target_size(size) masks_queries_logits = torch.nn.functional.interpolate( @@ -449,15 +450,15 @@ class EomtImageProcessorFast(BaseImageProcessorFast): segmentation_logits = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs) - output_logits = self.merge_image_patches(segmentation_logits, patch_offsets, original_image_sizes, size) + output_logits = self.merge_image_patches(segmentation_logits, patch_offsets, target_sizes, size) - preds = torch.stack(output_logits).argmax(dim=1) + preds = [logit.argmax(dim=0) for logit in output_logits] return preds def post_process_panoptic_segmentation( self, outputs, - original_image_sizes: list[tuple[int, int]], + target_sizes: list[tuple[int, int]], threshold: float = 0.8, mask_threshold: float = 0.5, overlap_mask_area_threshold: float = 0.8, @@ -481,7 +482,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast): mode="bilinear", ) - mask_probs_batch = self.unpad_image(masks_queries_logits, original_image_sizes, size) + mask_probs_batch = self.unpad_image(masks_queries_logits, target_sizes, size) pred_scores_batch, pred_labels_batch = class_queries_logits.softmax(dim=-1).max(-1) results: list = [] @@ -493,7 +494,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast): # No mask found if mask_probs.shape[0] <= 0: - height, width = original_image_sizes[i] if original_image_sizes is not None else mask_probs.shape[1:] + height, width = target_sizes[i] if target_sizes is not None else mask_probs.shape[1:] segmentation = torch.zeros((height, width)) - 1 results.append({"segmentation": segmentation, "segments_info": []}) continue @@ -505,16 +506,17 @@ class EomtImageProcessorFast(BaseImageProcessorFast): stuff_classes=stuff_classes, mask_threshold=mask_threshold, overlap_mask_area_threshold=overlap_mask_area_threshold, - target_size=original_image_sizes[i] if original_image_sizes is not None else None, + target_size=target_sizes[i] if target_sizes is not None else None, ) results.append({"segmentation": segmentation, "segments_info": segments}) return results + @filter_out_non_signature_kwargs() def post_process_instance_segmentation( self, outputs, - original_image_sizes: list[tuple[int, int]], + target_sizes: list[tuple[int, int]], threshold: float = 0.8, size: Optional[dict[str, int]] = None, ): @@ -532,7 +534,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast): mode="bilinear", ) - mask_probs_batch = self.unpad_image(masks_queries_logits, original_image_sizes, size) + mask_probs_batch = self.unpad_image(masks_queries_logits, target_sizes, size) device = masks_queries_logits.device batch_size = class_queries_logits.shape[0] @@ -554,7 +556,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast): ) pred_scores = scores * mask_scores - segmentation = torch.zeros(original_image_sizes[i], device=device) - 1 + segmentation = torch.zeros(target_sizes[i], device=device) - 1 instance_maps, segments = [], [] current_segment_id = 0 diff --git a/src/transformers/models/eomt/modeling_eomt.py b/src/transformers/models/eomt/modeling_eomt.py index bbdd11e1f5..bc865988ca 100644 --- a/src/transformers/models/eomt/modeling_eomt.py +++ b/src/transformers/models/eomt/modeling_eomt.py @@ -74,6 +74,8 @@ class EomtForUniversalSegmentationOutput(ModelOutput): attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): Tuple of `tuple(torch.FloatTensor)` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Self and Cross Attentions weights from transformer decoder. + patch_offsets (`list[torch.Tensor]`, *optional*): + list of tuples indicating the image index and start and end positions of patches for semantic segementation. """ loss: Optional[torch.FloatTensor] = None @@ -82,6 +84,7 @@ class EomtForUniversalSegmentationOutput(ModelOutput): last_hidden_state: Optional[torch.FloatTensor] = None hidden_states: Optional[tuple[torch.FloatTensor]] = None attentions: Optional[tuple[torch.FloatTensor]] = None + patch_offsets: Optional[list[torch.Tensor]] = None # Adapted from https://github.com/facebookresearch/detectron2/blob/main/projects/PointRend/point_rend/point_features.py @@ -996,7 +999,7 @@ class EomtPreTrainedModel(PreTrainedModel): base_model_prefix = "eomt" main_input_name = "pixel_values" supports_gradient_checkpointing = False - _no_split_modules = ["EomtMLP"] + _no_split_modules = ["EomtLayer"] _supports_sdpa = True _supports_flash_attn_2 = True @@ -1097,13 +1100,16 @@ class EomtForUniversalSegmentation(EomtPreTrainedModel): class_labels: Optional[list[Tensor]] = None, output_hidden_states: Optional[bool] = None, output_attentions: Optional[bool] = None, + patch_offsets: Optional[list[Tensor]] = None, ) -> EomtForUniversalSegmentationOutput: r""" - mask_labels (`List[torch.Tensor]`, *optional*): - List of mask labels of shape `(num_labels, height, width)` to be fed to a model - class_labels (`List[torch.LongTensor]`, *optional*): + mask_labels (`list[torch.Tensor]`, *optional*): + list of mask labels of shape `(num_labels, height, width)` to be fed to a model + class_labels (`list[torch.LongTensor]`, *optional*): list of target class labels of shape `(num_labels, height, width)` to be fed to a model. They identify the labels of `mask_labels`, e.g. the label of `mask_labels[i][j]` if `class_labels[i][j]`. + patch_offsets (`list[torch.Tensor]`, *optional*): + list of tuples indicating the image index and start and end positions of patches for semantic segementation. """ output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -1126,7 +1132,7 @@ class EomtForUniversalSegmentation(EomtPreTrainedModel): all_hidden_states += (hidden_states,) if idx == self.num_hidden_layers - self.config.num_blocks: - query = self.query.weight[None, :, :].expand(hidden_states.shape[0], -1, -1) + query = self.query.weight[None, :, :].expand(hidden_states.shape[0], -1, -1).to(hidden_states.device) hidden_states = torch.cat((query, hidden_states), dim=1) if idx >= self.num_hidden_layers - self.config.num_blocks and ( @@ -1206,6 +1212,7 @@ class EomtForUniversalSegmentation(EomtPreTrainedModel): last_hidden_state=sequence_output, hidden_states=all_hidden_states, attentions=all_attentions, + patch_offsets=patch_offsets, ) def get_input_embeddings(self): diff --git a/src/transformers/models/eomt/modular_eomt.py b/src/transformers/models/eomt/modular_eomt.py index fc82836e4b..44ecb69eca 100644 --- a/src/transformers/models/eomt/modular_eomt.py +++ b/src/transformers/models/eomt/modular_eomt.py @@ -226,6 +226,8 @@ class EomtForUniversalSegmentationOutput(ModelOutput): attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): Tuple of `tuple(torch.FloatTensor)` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Self and Cross Attentions weights from transformer decoder. + patch_offsets (`list[torch.Tensor]`, *optional*): + list of tuples indicating the image index and start and end positions of patches for semantic segementation. """ loss: Optional[torch.FloatTensor] = None @@ -234,6 +236,7 @@ class EomtForUniversalSegmentationOutput(ModelOutput): last_hidden_state: Optional[torch.FloatTensor] = None hidden_states: Optional[tuple[torch.FloatTensor]] = None attentions: Optional[tuple[torch.FloatTensor]] = None + patch_offsets: Optional[list[torch.Tensor]] = None class EomtLoss(Mask2FormerLoss): @@ -368,7 +371,7 @@ class EomtPreTrainedModel(PreTrainedModel): base_model_prefix = "eomt" main_input_name = "pixel_values" supports_gradient_checkpointing = False - _no_split_modules = ["EomtMLP"] + _no_split_modules = ["EomtLayer"] _supports_sdpa = True _supports_flash_attn_2 = True @@ -473,13 +476,16 @@ class EomtForUniversalSegmentation(Mask2FormerForUniversalSegmentation, nn.Modul class_labels: Optional[list[Tensor]] = None, output_hidden_states: Optional[bool] = None, output_attentions: Optional[bool] = None, + patch_offsets: Optional[list[Tensor]] = None, ): r""" - mask_labels (`List[torch.Tensor]`, *optional*): - List of mask labels of shape `(num_labels, height, width)` to be fed to a model - class_labels (`List[torch.LongTensor]`, *optional*): + mask_labels (`list[torch.Tensor]`, *optional*): + list of mask labels of shape `(num_labels, height, width)` to be fed to a model + class_labels (`list[torch.LongTensor]`, *optional*): list of target class labels of shape `(num_labels, height, width)` to be fed to a model. They identify the labels of `mask_labels`, e.g. the label of `mask_labels[i][j]` if `class_labels[i][j]`. + patch_offsets (`list[torch.Tensor]`, *optional*): + list of tuples indicating the image index and start and end positions of patches for semantic segementation. """ output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -502,7 +508,7 @@ class EomtForUniversalSegmentation(Mask2FormerForUniversalSegmentation, nn.Modul all_hidden_states += (hidden_states,) if idx == self.num_hidden_layers - self.config.num_blocks: - query = self.query.weight[None, :, :].expand(hidden_states.shape[0], -1, -1) + query = self.query.weight[None, :, :].expand(hidden_states.shape[0], -1, -1).to(hidden_states.device) hidden_states = torch.cat((query, hidden_states), dim=1) if idx >= self.num_hidden_layers - self.config.num_blocks and ( @@ -582,6 +588,7 @@ class EomtForUniversalSegmentation(Mask2FormerForUniversalSegmentation, nn.Modul last_hidden_state=sequence_output, hidden_states=all_hidden_states, attentions=all_attentions, + patch_offsets=patch_offsets, ) diff --git a/tests/models/eomt/test_image_processing_eomt.py b/tests/models/eomt/test_image_processing_eomt.py index 6d449453de..594a1d9fe8 100644 --- a/tests/models/eomt/test_image_processing_eomt.py +++ b/tests/models/eomt/test_image_processing_eomt.py @@ -84,10 +84,11 @@ class EomtImageProcessingTester: "num_labels": self.num_labels, } - def prepare_fake_eomt_outputs(self, batch_size): + def prepare_fake_eomt_outputs(self, batch_size, patch_offsets=None): return EomtForUniversalSegmentationOutput( masks_queries_logits=torch.randn((batch_size, self.num_queries, self.height, self.width)), class_queries_logits=torch.randn((batch_size, self.num_queries, self.num_classes + 1)), + patch_offsets=patch_offsets, ) def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False): @@ -263,13 +264,13 @@ class EomtImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw) inputs = processor(images=image, do_split_image=True, return_tensors="pt") - patch_offsets = inputs.pop("patch_offsets") + patch_offsets = inputs["patch_offsets"] - original_sizes = [image.size[::-1]] + target_sizes = [image.size[::-1]] # For semantic segmentation, the BS of output is 2 coz, two patches are created for the image. - outputs = self.image_processor_tester.prepare_fake_eomt_outputs(inputs["pixel_values"].shape[0]) - segmentation = processor.post_process_semantic_segmentation(outputs, patch_offsets, original_sizes) + outputs = self.image_processor_tester.prepare_fake_eomt_outputs(inputs["pixel_values"].shape[0], patch_offsets) + segmentation = processor.post_process_semantic_segmentation(outputs, target_sizes) self.assertEqual(segmentation[0].shape, (image.height, image.width)) diff --git a/tests/models/eomt/test_modeling_eomt.py b/tests/models/eomt/test_modeling_eomt.py index c526030250..c4b026cc18 100644 --- a/tests/models/eomt/test_modeling_eomt.py +++ b/tests/models/eomt/test_modeling_eomt.py @@ -17,12 +17,13 @@ import unittest import requests -from transformers import AutoImageProcessor, EomtConfig, EomtForUniversalSegmentation +from transformers import AutoImageProcessor, EomtConfig, EomtForUniversalSegmentation, pipeline from transformers.testing_utils import require_torch, require_torch_accelerator, require_torch_fp16, slow, torch_device from transformers.utils import is_torch_available, is_vision_available from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor +from ...test_pipeline_mixin import PipelineTesterMixin if is_torch_available(): @@ -100,8 +101,9 @@ class EomtForUniversalSegmentationTester: @require_torch -class EomtForUniversalSegmentationTest(ModelTesterMixin, unittest.TestCase): +class EomtForUniversalSegmentationTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = (EomtForUniversalSegmentation,) if is_torch_available() else () + pipeline_model_mapping = {"image-segmentation": EomtForUniversalSegmentation} if is_torch_available() else {} is_encoder_decoder = False test_pruning = False test_head_masking = False @@ -340,7 +342,6 @@ class EomtForUniversalSegmentationIntegrationTest(unittest.TestCase): image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw) inputs = processor(images=image, return_tensors="pt").to(model.device) - patch_offsets = inputs.pop("patch_offsets", None) with torch.inference_mode(): outputs = model(**inputs) @@ -348,11 +349,9 @@ class EomtForUniversalSegmentationIntegrationTest(unittest.TestCase): self.assertTrue(outputs.class_queries_logits.shape == (2, 100, 151)) self.assertTrue(outputs.masks_queries_logits.shape == (2, 100, 128, 128)) - preds = processor.post_process_semantic_segmentation( - outputs, original_image_sizes=[(image.size[1], image.size[0])], patch_offsets=patch_offsets - ) + preds = processor.post_process_semantic_segmentation(outputs, target_sizes=[(image.size[1], image.size[0])])[0] - self.assertTrue(preds.shape[1:] == (image.size[1], image.size[0])) + self.assertTrue(preds.shape == (image.size[1], image.size[0])) # fmt: off EXPECTED_SLICE = torch.tensor([ @@ -369,7 +368,7 @@ class EomtForUniversalSegmentationIntegrationTest(unittest.TestCase): ], device=model.device) # fmt: on - output_slice = preds[0, :10, :10] + output_slice = preds[:10, :10] torch.testing.assert_close(output_slice, EXPECTED_SLICE, rtol=1e-2, atol=1e-2) @slow @@ -387,9 +386,7 @@ class EomtForUniversalSegmentationIntegrationTest(unittest.TestCase): self.assertTrue(outputs.class_queries_logits.shape == (1, 200, 134)) self.assertTrue(outputs.masks_queries_logits.shape == (1, 200, 160, 160)) - preds = processor.post_process_panoptic_segmentation( - outputs, original_image_sizes=[(image.size[1], image.size[0])] - )[0] + preds = processor.post_process_panoptic_segmentation(outputs, target_sizes=[(image.size[1], image.size[0])])[0] segmentation, segments_info = preds["segmentation"], preds["segments_info"] # fmt: off @@ -438,9 +435,7 @@ class EomtForUniversalSegmentationIntegrationTest(unittest.TestCase): self.assertTrue(outputs.class_queries_logits.shape == (1, 200, 81)) self.assertTrue(outputs.masks_queries_logits.shape == (1, 200, 160, 160)) - preds = processor.post_process_instance_segmentation( - outputs, original_image_sizes=[(image.size[1], image.size[0])] - )[0] + preds = processor.post_process_instance_segmentation(outputs, target_sizes=[(image.size[1], image.size[0])])[0] segmentation, segments_info = preds["segmentation"], preds["segments_info"] # fmt: off @@ -473,3 +468,15 @@ class EomtForUniversalSegmentationIntegrationTest(unittest.TestCase): self.assertEqual(actual["id"], expected["id"]) self.assertEqual(actual["label_id"], expected["label_id"]) self.assertAlmostEqual(actual["score"], expected["score"], delta=1e-3) + + @slow + def test_segmentation_pipeline(self): + image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw) + + pipe = pipeline(model=self.model_id, subtask="panoptic", device=torch_device) + output = pipe(image) + + EXPECTED_OUTPUT_LABELS = ["cat", "cat", "couch", "remote", "remote"] + + output_labels = [segment["label"] for segment in output] + self.assertEqual(output_labels, EXPECTED_OUTPUT_LABELS)