Improve OWL-ViT postprocessing (#20980)
* add post_process_object_detection method
* style changes
This commit is contained in:
@@ -80,7 +80,7 @@ This model was contributed by [adirik](https://huggingface.co/adirik). The origi
|
|||||||
|
|
||||||
[[autodoc]] OwlViTImageProcessor
|
[[autodoc]] OwlViTImageProcessor
|
||||||
- preprocess
|
- preprocess
|
||||||
- post_process
|
- post_process_object_detection
|
||||||
- post_process_image_guided_detection
|
- post_process_image_guided_detection
|
||||||
|
|
||||||
## OwlViTFeatureExtractor
|
## OwlViTFeatureExtractor
|
||||||
|
|||||||
@@ -14,7 +14,8 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Image processor class for OwlViT"""
|
"""Image processor class for OwlViT"""
|
||||||
|
|
||||||
from typing import Dict, List, Optional, Union
|
import warnings
|
||||||
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -344,6 +345,12 @@ class OwlViTImageProcessor(BaseImageProcessor):
|
|||||||
in the batch as predicted by the model.
|
in the batch as predicted by the model.
|
||||||
"""
|
"""
|
||||||
# TODO: (amy) add support for other frameworks
|
# TODO: (amy) add support for other frameworks
|
||||||
|
warnings.warn(
|
||||||
|
"`post_process` is deprecated and will be removed in v5 of Transformers, please use"
|
||||||
|
" `post_process_object_detection`",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
|
||||||
logits, boxes = outputs.logits, outputs.pred_boxes
|
logits, boxes = outputs.logits, outputs.pred_boxes
|
||||||
|
|
||||||
if len(logits) != len(target_sizes):
|
if len(logits) != len(target_sizes):
|
||||||
@@ -367,6 +374,61 @@ class OwlViTImageProcessor(BaseImageProcessor):
|
|||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
def post_process_object_detection(
    self, outputs, threshold: float = 0.1, target_sizes: Optional[Union[TensorType, List[Tuple]]] = None
):
    """
    Converts the raw output of [`OwlViTForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
    bottom_right_x, bottom_right_y) format.

    Args:
        outputs ([`OwlViTObjectDetectionOutput`]):
            Raw outputs of the model.
        threshold (`float`, *optional*, defaults to 0.1):
            Score threshold to keep object detection predictions. Only predictions whose score is strictly
            greater than this value are returned.
        target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*):
            Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
            `(height, width)` of each image in the batch. If unset, predictions will not be resized.
    Returns:
        `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
        in the batch as predicted by the model.
    Raises:
        ValueError: If `target_sizes` is given and its length differs from the batch dimension of the logits.
    """
    # TODO: (amy) add support for other frameworks
    logits, boxes = outputs.logits, outputs.pred_boxes

    if target_sizes is not None and len(logits) != len(target_sizes):
        raise ValueError(
            "Make sure that you pass in as many target sizes as the batch dimension of the logits"
        )

    # Sigmoid is monotonic, so taking the max over raw logits first and squashing afterwards yields the best
    # class score per box while applying the sigmoid to a single value per box only.
    best = torch.max(logits, dim=-1)
    scores = torch.sigmoid(best.values)
    labels = best.indices

    # Convert to [x0, y0, x1, y1] format
    boxes = center_to_corners_format(boxes)

    # Convert from relative [0, 1] to absolute [0, width] / [0, height] coordinates
    if target_sizes is not None:
        if isinstance(target_sizes, (list, tuple)):
            img_h = torch.Tensor([size[0] for size in target_sizes])
            img_w = torch.Tensor([size[1] for size in target_sizes])
        else:
            img_h, img_w = target_sizes.unbind(1)

        # The sizes usually live on the CPU (or are plain Python numbers) while the predictions may be on an
        # accelerator; move the scale factor next to the boxes so the multiplication cannot fail with a
        # device mismatch.
        scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
        boxes = boxes * scale_fct[:, None, :]

    results = []
    for image_scores, image_labels, image_boxes in zip(scores, labels, boxes):
        # Build the keep-mask once and reuse it for all three outputs.
        keep = image_scores > threshold
        results.append(
            {"scores": image_scores[keep], "labels": image_labels[keep], "boxes": image_boxes[keep]}
        )

    return results
|
||||||
|
|
||||||
# TODO: (Amy) Make compatible with other frameworks
|
# TODO: (Amy) Make compatible with other frameworks
|
||||||
def post_process_image_guided_detection(self, outputs, threshold=0.6, nms_threshold=0.3, target_sizes=None):
|
def post_process_image_guided_detection(self, outputs, threshold=0.6, nms_threshold=0.3, target_sizes=None):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -204,8 +204,8 @@ class OwlViTObjectDetectionOutput(ModelOutput):
|
|||||||
pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`):
|
pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`):
|
||||||
Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
|
Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
|
||||||
values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
|
values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
|
||||||
possible padding). You can use [`~OwlViTFeatureExtractor.post_process`] to retrieve the unnormalized
|
possible padding). You can use [`~OwlViTFeatureExtractor.post_process_object_detection`] to retrieve the
|
||||||
bounding boxes.
|
unnormalized bounding boxes.
|
||||||
text_embeds (`torch.FloatTensor` of shape `(batch_size, num_max_text_queries, output_dim)`):
|
text_embeds (`torch.FloatTensor` of shape `(batch_size, num_max_text_queries, output_dim)`):
|
||||||
The text embeddings obtained by applying the projection layer to the pooled output of [`OwlViTTextModel`].
|
The text embeddings obtained by applying the projection layer to the pooled output of [`OwlViTTextModel`].
|
||||||
image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim)`):
|
image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim)`):
|
||||||
@@ -248,13 +248,13 @@ class OwlViTImageGuidedObjectDetectionOutput(ModelOutput):
|
|||||||
target_pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`):
|
target_pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`):
|
||||||
Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
|
Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
|
||||||
values are normalized in [0, 1], relative to the size of each individual target image in the batch
|
values are normalized in [0, 1], relative to the size of each individual target image in the batch
|
||||||
(disregarding possible padding). You can use [`~OwlViTFeatureExtractor.post_process`] to retrieve the
|
(disregarding possible padding). You can use [`~OwlViTFeatureExtractor.post_process_object_detection`] to
|
||||||
unnormalized bounding boxes.
|
retrieve the unnormalized bounding boxes.
|
||||||
query_pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`):
|
query_pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`):
|
||||||
Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
|
Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
|
||||||
values are normalized in [0, 1], relative to the size of each individual query image in the batch
|
values are normalized in [0, 1], relative to the size of each individual query image in the batch
|
||||||
(disregarding possible padding). You can use [`~OwlViTFeatureExtractor.post_process`] to retrieve the
|
(disregarding possible padding). You can use [`~OwlViTFeatureExtractor.post_process_object_detection`] to
|
||||||
unnormalized bounding boxes.
|
retrieve the unnormalized bounding boxes.
|
||||||
image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim)`):
|
image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim)`):
|
||||||
Pooled output of [`OwlViTVisionModel`]. OWL-ViT represents images as a set of image patches and computes
|
Pooled output of [`OwlViTVisionModel`]. OWL-ViT represents images as a set of image patches and computes
|
||||||
image embeddings for each patch.
|
image embeddings for each patch.
|
||||||
@@ -1644,18 +1644,18 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
|
|||||||
|
|
||||||
>>> # Target image sizes (height, width) to rescale box predictions [batch_size, 2]
|
>>> # Target image sizes (height, width) to rescale box predictions [batch_size, 2]
|
||||||
>>> target_sizes = torch.Tensor([image.size[::-1]])
|
>>> target_sizes = torch.Tensor([image.size[::-1]])
|
||||||
>>> # Convert outputs (bounding boxes and class logits) to COCO API
|
>>> # Convert outputs (bounding boxes and class logits) to final bounding boxes and scores
|
||||||
>>> results = processor.post_process(outputs=outputs, target_sizes=target_sizes)
|
>>> results = processor.post_process_object_detection(
|
||||||
|
... outputs=outputs, threshold=0.1, target_sizes=target_sizes
|
||||||
|
... )
|
||||||
|
|
||||||
>>> i = 0 # Retrieve predictions for the first image for the corresponding text queries
|
>>> i = 0 # Retrieve predictions for the first image for the corresponding text queries
|
||||||
>>> text = texts[i]
|
>>> text = texts[i]
|
||||||
>>> boxes, scores, labels = results[i]["boxes"], results[i]["scores"], results[i]["labels"]
|
>>> boxes, scores, labels = results[i]["boxes"], results[i]["scores"], results[i]["labels"]
|
||||||
|
|
||||||
>>> score_threshold = 0.1
|
|
||||||
>>> for box, score, label in zip(boxes, scores, labels):
|
>>> for box, score, label in zip(boxes, scores, labels):
|
||||||
... box = [round(i, 2) for i in box.tolist()]
|
... box = [round(i, 2) for i in box.tolist()]
|
||||||
... if score >= score_threshold:
|
... print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}")
|
||||||
... print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}")
|
|
||||||
Detected a photo of a cat with confidence 0.707 at location [324.97, 20.44, 640.58, 373.29]
|
Detected a photo of a cat with confidence 0.707 at location [324.97, 20.44, 640.58, 373.29]
|
||||||
Detected a photo of a cat with confidence 0.717 at location [1.46, 55.26, 315.55, 472.17]
|
Detected a photo of a cat with confidence 0.717 at location [1.46, 55.26, 315.55, 472.17]
|
||||||
```"""
|
```"""
|
||||||
|
|||||||
@@ -179,6 +179,13 @@ class OwlViTProcessor(ProcessorMixin):
|
|||||||
"""
|
"""
|
||||||
return self.image_processor.post_process(*args, **kwargs)
|
return self.image_processor.post_process(*args, **kwargs)
|
||||||
|
|
||||||
|
def post_process_object_detection(self, *args, **kwargs):
    """
    Thin pass-through to [`OwlViTImageProcessor.post_process_object_detection`]: every positional and keyword
    argument is handed to the wrapped image processor unchanged and its result is returned as-is. See the
    docstring of that method for the accepted arguments and the output format.
    """
    delegate = self.image_processor.post_process_object_detection
    return delegate(*args, **kwargs)
|
||||||
|
|
||||||
def post_process_image_guided_detection(self, *args, **kwargs):
|
def post_process_image_guided_detection(self, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
This method forwards all its arguments to [`OwlViTImageProcessor.post_process_image_guided_detection`].
|
This method forwards all its arguments to [`OwlViTImageProcessor.post_process_image_guided_detection`].
|
||||||
|
|||||||
@@ -173,12 +173,11 @@ class ZeroShotObjectDetectionPipeline(ChunkPipeline):
|
|||||||
for model_output in model_outputs:
|
for model_output in model_outputs:
|
||||||
label = model_output["candidate_label"]
|
label = model_output["candidate_label"]
|
||||||
model_output = BaseModelOutput(model_output)
|
model_output = BaseModelOutput(model_output)
|
||||||
outputs = self.feature_extractor.post_process(
|
outputs = self.feature_extractor.post_process_object_detection(
|
||||||
outputs=model_output, target_sizes=model_output["target_size"]
|
outputs=model_output, threshold=threshold, target_sizes=model_output["target_size"]
|
||||||
)[0]
|
)[0]
|
||||||
keep = outputs["scores"] >= threshold
|
|
||||||
|
|
||||||
for index in keep.nonzero():
|
for index in outputs["scores"].nonzero():
|
||||||
score = outputs["scores"][index].item()
|
score = outputs["scores"][index].item()
|
||||||
box = self._get_bounding_box(outputs["boxes"][index][0])
|
box = self._get_bounding_box(outputs["boxes"][index][0])
|
||||||
|
|
||||||
|
|||||||
@@ -131,7 +131,8 @@ class ZeroShotObjectDetectionPipelineTests(unittest.TestCase, metaclass=Pipeline
|
|||||||
object_detector = pipeline("zero-shot-object-detection")
|
object_detector = pipeline("zero-shot-object-detection")
|
||||||
|
|
||||||
outputs = object_detector(
|
outputs = object_detector(
|
||||||
"http://images.cocodataset.org/val2017/000000039769.jpg", candidate_labels=["cat", "remote", "couch"]
|
"http://images.cocodataset.org/val2017/000000039769.jpg",
|
||||||
|
candidate_labels=["cat", "remote", "couch"],
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
nested_simplify(outputs, decimals=4),
|
nested_simplify(outputs, decimals=4),
|
||||||
|
|||||||
Reference in New Issue
Block a user