OwlViT/Owlv2 post processing standardization (#34929)

* Refactor owlvit post_process_object_detection + add text_labels

* Fix copies in grounding dino

* Sync with Owlv2 postprocessing

* Add post_process_grounded_object_detection method to processor, deprecate post_process_object_detection

* Add test cases

* Move text_labels to processors only

* [run-slow] owlvit owlv2

* [run-slow] owlvit, owlv2

* Update snippets

* Update docs structure

* Update deprecated objects for check_repo

* Update docstring for post processing of image guided object detection
This commit is contained in:
Pavel Iakubovskii
2025-01-17 13:58:28 +00:00
committed by GitHub
parent add5f0566c
commit 94ae9a8da1
12 changed files with 467 additions and 188 deletions

View File

@@ -967,8 +967,9 @@ class OwlViTModelIntegrationTest(unittest.TestCase):
processor = OwlViTProcessor.from_pretrained(model_name)
image = prepare_img()
text_labels = [["a photo of a cat", "a photo of a dog"]]
inputs = processor(
text=[["a photo of a cat", "a photo of a dog"]],
text=text_labels,
images=image,
max_length=16,
padding="max_length",
@@ -986,6 +987,21 @@ class OwlViTModelIntegrationTest(unittest.TestCase):
).to(torch_device)
self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4))
# test post-processing
post_processed_output = processor.post_process_grounded_object_detection(outputs)
self.assertIsNone(post_processed_output[0]["text_labels"])
post_processed_output_with_text_labels = processor.post_process_grounded_object_detection(
outputs, text_labels=text_labels
)
objects_labels = post_processed_output_with_text_labels[0]["labels"].cpu().tolist()
self.assertListEqual(objects_labels, [0, 0])
objects_text_labels = post_processed_output_with_text_labels[0]["text_labels"]
self.assertIsNotNone(objects_text_labels)
self.assertListEqual(objects_text_labels, ["a photo of a cat", "a photo of a cat"])
@slow
def test_inference_one_shot_object_detection(self):
model_name = "google/owlvit-base-patch32"