From ea9caf7abaf5be1dd7d6167c14662fb8c7ac1f53 Mon Sep 17 00:00:00 2001 From: Rafael Padilla <31217453+rafaelpadilla@users.noreply.github.com> Date: Tue, 4 Jul 2023 16:47:57 -0300 Subject: [PATCH] Update warning messages reffering to post_process_object_detection (#24649) * including the threshold alert in warning messages. * Updating doc owlvit.md including post_process_object_detection function with threshold. * fix --- docs/source/en/model_doc/owlvit.md | 8 ++------ .../conditional_detr/image_processing_conditional_detr.py | 2 +- .../deformable_detr/image_processing_deformable_detr.py | 2 +- src/transformers/models/detr/image_processing_detr.py | 2 +- src/transformers/models/owlvit/image_processing_owlvit.py | 2 +- src/transformers/models/yolos/image_processing_yolos.py | 2 +- 6 files changed, 7 insertions(+), 11 deletions(-) diff --git a/docs/source/en/model_doc/owlvit.md b/docs/source/en/model_doc/owlvit.md index 737e42e253..b18b80b405 100644 --- a/docs/source/en/model_doc/owlvit.md +++ b/docs/source/en/model_doc/owlvit.md @@ -50,17 +50,13 @@ OWL-ViT is a zero-shot text-conditioned object detection model. OWL-ViT uses [CL >>> # Target image sizes (height, width) to rescale box predictions [batch_size, 2] >>> target_sizes = torch.Tensor([image.size[::-1]]) >>> # Convert outputs (bounding boxes and class logits) to COCO API ->>> results = processor.post_process(outputs=outputs, target_sizes=target_sizes) - +>>> results = processor.post_process_object_detection(outputs=outputs, target_sizes=target_sizes, threshold=0.1) >>> i = 0 # Retrieve predictions for the first image for the corresponding text queries >>> text = texts[i] >>> boxes, scores, labels = results[i]["boxes"], results[i]["scores"], results[i]["labels"] - ->>> score_threshold = 0.1 >>> for box, score, label in zip(boxes, scores, labels): ... box = [round(i, 2) for i in box.tolist()] -... if score >= score_threshold: -... print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}") +... print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}") Detected a photo of a cat with confidence 0.707 at location [324.97, 20.44, 640.58, 373.29] Detected a photo of a cat with confidence 0.717 at location [1.46, 55.26, 315.55, 472.17] ``` diff --git a/src/transformers/models/conditional_detr/image_processing_conditional_detr.py b/src/transformers/models/conditional_detr/image_processing_conditional_detr.py index 3de243cd86..390713c0a5 100644 --- a/src/transformers/models/conditional_detr/image_processing_conditional_detr.py +++ b/src/transformers/models/conditional_detr/image_processing_conditional_detr.py @@ -1250,7 +1250,7 @@ class ConditionalDetrImageProcessor(BaseImageProcessor): """ logging.warning_once( "`post_process` is deprecated and will be removed in v5 of Transformers, please use" - " `post_process_object_detection`", + " `post_process_object_detection` instead, with `threshold=0.` for equivalent results.", ) out_logits, out_bbox = outputs.logits, outputs.pred_boxes diff --git a/src/transformers/models/deformable_detr/image_processing_deformable_detr.py b/src/transformers/models/deformable_detr/image_processing_deformable_detr.py index 6aa3d5a82f..ae756c0a2b 100644 --- a/src/transformers/models/deformable_detr/image_processing_deformable_detr.py +++ b/src/transformers/models/deformable_detr/image_processing_deformable_detr.py @@ -1248,7 +1248,7 @@ class DeformableDetrImageProcessor(BaseImageProcessor): """ logger.warning_once( "`post_process` is deprecated and will be removed in v5 of Transformers, please use" - " `post_process_object_detection`.", + " `post_process_object_detection` instead, with `threshold=0.` for equivalent results.", ) out_logits, out_bbox = outputs.logits, outputs.pred_boxes diff --git a/src/transformers/models/detr/image_processing_detr.py b/src/transformers/models/detr/image_processing_detr.py index 5983506eff..af14c38dcc 100644 --- a/src/transformers/models/detr/image_processing_detr.py +++ b/src/transformers/models/detr/image_processing_detr.py @@ -1219,7 +1219,7 @@ class DetrImageProcessor(BaseImageProcessor): """ logger.warning_once( "`post_process` is deprecated and will be removed in v5 of Transformers, please use" - " `post_process_object_detection`", + " `post_process_object_detection` instead, with `threshold=0.` for equivalent results.", ) out_logits, out_bbox = outputs.logits, outputs.pred_boxes diff --git a/src/transformers/models/owlvit/image_processing_owlvit.py b/src/transformers/models/owlvit/image_processing_owlvit.py index ea5cd7b776..9268aac15b 100644 --- a/src/transformers/models/owlvit/image_processing_owlvit.py +++ b/src/transformers/models/owlvit/image_processing_owlvit.py @@ -354,7 +354,7 @@ class OwlViTImageProcessor(BaseImageProcessor): # TODO: (amy) add support for other frameworks warnings.warn( "`post_process` is deprecated and will be removed in v5 of Transformers, please use" - " `post_process_object_detection`", + " `post_process_object_detection` instead, with `threshold=0.` for equivalent results.", FutureWarning, ) diff --git a/src/transformers/models/yolos/image_processing_yolos.py b/src/transformers/models/yolos/image_processing_yolos.py index a472674171..372ab2b3ce 100644 --- a/src/transformers/models/yolos/image_processing_yolos.py +++ b/src/transformers/models/yolos/image_processing_yolos.py @@ -1151,7 +1151,7 @@ class YolosImageProcessor(BaseImageProcessor): """ logger.warning_once( "`post_process` is deprecated and will be removed in v5 of Transformers, please use" - " `post_process_object_detection`", + " `post_process_object_detection` instead, with `threshold=0.` for equivalent results.", ) out_logits, out_bbox = outputs.logits, outputs.pred_boxes