diff --git a/docs/source/en/model_doc/owlvit.mdx b/docs/source/en/model_doc/owlvit.mdx index b95efd59b5..f13ad4a540 100644 --- a/docs/source/en/model_doc/owlvit.mdx +++ b/docs/source/en/model_doc/owlvit.mdx @@ -80,7 +80,7 @@ This model was contributed by [adirik](https://huggingface.co/adirik). The origi [[autodoc]] OwlViTImageProcessor - preprocess - - post_process + - post_process_object_detection - post_process_image_guided_detection ## OwlViTFeatureExtractor diff --git a/src/transformers/models/owlvit/image_processing_owlvit.py b/src/transformers/models/owlvit/image_processing_owlvit.py index b07e164754..fc3f0fa331 100644 --- a/src/transformers/models/owlvit/image_processing_owlvit.py +++ b/src/transformers/models/owlvit/image_processing_owlvit.py @@ -14,7 +14,8 @@ # limitations under the License. """Image processor class for OwlViT""" -from typing import Dict, List, Optional, Union +import warnings +from typing import Dict, List, Optional, Tuple, Union import numpy as np @@ -344,6 +345,12 @@ class OwlViTImageProcessor(BaseImageProcessor): in the batch as predicted by the model. """ # TODO: (amy) add support for other frameworks + warnings.warn( + "`post_process` is deprecated and will be removed in v5 of Transformers, please use" + " `post_process_object_detection`", + FutureWarning, + ) + logits, boxes = outputs.logits, outputs.pred_boxes if len(logits) != len(target_sizes): @@ -367,6 +374,61 @@ class OwlViTImageProcessor(BaseImageProcessor): return results + def post_process_object_detection( + self, outputs, threshold: float = 0.1, target_sizes: Union[TensorType, List[Tuple]] = None + ): + """ + Converts the raw output of [`OwlViTForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y, + bottom_right_x, bottom_right_y) format. + + Args: + outputs ([`OwlViTObjectDetectionOutput`]): + Raw outputs of the model. + threshold (`float`, *optional*): + Score threshold to keep object detection predictions. + target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*): + Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size + `(height, width)` of each image in the batch. If unset, predictions will not be resized. + Returns: + `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image + in the batch as predicted by the model. + """ + # TODO: (amy) add support for other frameworks + logits, boxes = outputs.logits, outputs.pred_boxes + + if target_sizes is not None: + if len(logits) != len(target_sizes): + raise ValueError( + "Make sure that you pass in as many target sizes as the batch dimension of the logits" + ) + + probs = torch.max(logits, dim=-1) + scores = torch.sigmoid(probs.values) + labels = probs.indices + + # Convert to [x0, y0, x1, y1] format + boxes = center_to_corners_format(boxes) + + # Convert from relative [0, 1] to absolute [0, height] coordinates + if target_sizes is not None: + if isinstance(target_sizes, List): + img_h = torch.Tensor([i[0] for i in target_sizes]) + img_w = torch.Tensor([i[1] for i in target_sizes]) + else: + img_h, img_w = target_sizes.unbind(1) + + scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1) + boxes = boxes * scale_fct[:, None, :] + + results = [] + for s, l, b in zip(scores, labels, boxes): + score = s[s > threshold] + label = l[s > threshold] + box = b[s > threshold] + results.append({"scores": score, "labels": label, "boxes": box}) + + return results + # TODO: (Amy) Make compatible with other frameworks def post_process_image_guided_detection(self, outputs, threshold=0.6, nms_threshold=0.3, target_sizes=None): """ diff --git a/src/transformers/models/owlvit/modeling_owlvit.py b/src/transformers/models/owlvit/modeling_owlvit.py index fd0f30b79d..39f1334834 100644 --- a/src/transformers/models/owlvit/modeling_owlvit.py +++ b/src/transformers/models/owlvit/modeling_owlvit.py @@ -204,8 +204,8 @@ class OwlViTObjectDetectionOutput(ModelOutput): pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`): Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding - possible padding). You can use [`~OwlViTFeatureExtractor.post_process`] to retrieve the unnormalized - bounding boxes. + possible padding). You can use [`~OwlViTFeatureExtractor.post_process_object_detection`] to retrieve the + unnormalized bounding boxes. text_embeds (`torch.FloatTensor` of shape `(batch_size, num_max_text_queries, output_dim`): The text embeddings obtained by applying the projection layer to the pooled output of [`OwlViTTextModel`]. image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim`): @@ -248,13 +248,13 @@ class OwlViTImageGuidedObjectDetectionOutput(ModelOutput): target_pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`): Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These values are normalized in [0, 1], relative to the size of each individual target image in the batch - (disregarding possible padding). You can use [`~OwlViTFeatureExtractor.post_process`] to retrieve the - unnormalized bounding boxes. + (disregarding possible padding). You can use [`~OwlViTFeatureExtractor.post_process_object_detection`] to + retrieve the unnormalized bounding boxes. query_pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`): Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These values are normalized in [0, 1], relative to the size of each individual query image in the batch - (disregarding possible padding). You can use [`~OwlViTFeatureExtractor.post_process`] to retrieve the - unnormalized bounding boxes. + (disregarding possible padding). You can use [`~OwlViTFeatureExtractor.post_process_object_detection`] to + retrieve the unnormalized bounding boxes. image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim`): Pooled output of [`OwlViTVisionModel`]. OWL-ViT represents images as a set of image patches and computes image embeddings for each patch. @@ -1644,18 +1644,18 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel): >>> # Target image sizes (height, width) to rescale box predictions [batch_size, 2] >>> target_sizes = torch.Tensor([image.size[::-1]]) - >>> # Convert outputs (bounding boxes and class logits) to COCO API - >>> results = processor.post_process(outputs=outputs, target_sizes=target_sizes) + >>> # Convert outputs (bounding boxes and class logits) to final bounding boxes and scores + >>> results = processor.post_process_object_detection( + ... outputs=outputs, threshold=0.1, target_sizes=target_sizes + ... ) >>> i = 0 # Retrieve predictions for the first image for the corresponding text queries >>> text = texts[i] >>> boxes, scores, labels = results[i]["boxes"], results[i]["scores"], results[i]["labels"] - >>> score_threshold = 0.1 >>> for box, score, label in zip(boxes, scores, labels): ... box = [round(i, 2) for i in box.tolist()] - ... if score >= score_threshold: - ... print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}") + ... print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}") Detected a photo of a cat with confidence 0.707 at location [324.97, 20.44, 640.58, 373.29] Detected a photo of a cat with confidence 0.717 at location [1.46, 55.26, 315.55, 472.17] ```""" diff --git a/src/transformers/models/owlvit/processing_owlvit.py b/src/transformers/models/owlvit/processing_owlvit.py index 73fce77dd3..04b8c191ac 100644 --- a/src/transformers/models/owlvit/processing_owlvit.py +++ b/src/transformers/models/owlvit/processing_owlvit.py @@ -179,6 +179,13 @@ class OwlViTProcessor(ProcessorMixin): """ return self.image_processor.post_process(*args, **kwargs) + def post_process_object_detection(self, *args, **kwargs): + """ + This method forwards all its arguments to [`OwlViTImageProcessor.post_process_object_detection`]. Please refer + to the docstring of this method for more information. + """ + return self.image_processor.post_process_object_detection(*args, **kwargs) + def post_process_image_guided_detection(self, *args, **kwargs): """ This method forwards all its arguments to [`OwlViTImageProcessor.post_process_one_shot_object_detection`]. diff --git a/src/transformers/pipelines/zero_shot_object_detection.py b/src/transformers/pipelines/zero_shot_object_detection.py index 9a978bc929..7f8c46c0d7 100644 --- a/src/transformers/pipelines/zero_shot_object_detection.py +++ b/src/transformers/pipelines/zero_shot_object_detection.py @@ -173,12 +173,11 @@ class ZeroShotObjectDetectionPipeline(ChunkPipeline): for model_output in model_outputs: label = model_output["candidate_label"] model_output = BaseModelOutput(model_output) - outputs = self.feature_extractor.post_process( - outputs=model_output, target_sizes=model_output["target_size"] + outputs = self.feature_extractor.post_process_object_detection( + outputs=model_output, threshold=threshold, target_sizes=model_output["target_size"] )[0] - keep = outputs["scores"] >= threshold - for index in keep.nonzero(): + for index in outputs["scores"].nonzero(): score = outputs["scores"][index].item() box = self._get_bounding_box(outputs["boxes"][index][0]) diff --git a/tests/pipelines/test_pipelines_zero_shot_object_detection.py b/tests/pipelines/test_pipelines_zero_shot_object_detection.py index 06a611b53a..c48b8c381d 100644 --- a/tests/pipelines/test_pipelines_zero_shot_object_detection.py +++ b/tests/pipelines/test_pipelines_zero_shot_object_detection.py @@ -131,7 +131,8 @@ class ZeroShotObjectDetectionPipelineTests(unittest.TestCase, metaclass=Pipeline object_detector = pipeline("zero-shot-object-detection") outputs = object_detector( - "http://images.cocodataset.org/val2017/000000039769.jpg", candidate_labels=["cat", "remote", "couch"] + "http://images.cocodataset.org/val2017/000000039769.jpg", + candidate_labels=["cat", "remote", "couch"], ) self.assertEqual( nested_simplify(outputs, decimals=4),