OwlViT/Owlv2 post processing standardization (#34929)

* Refactor owlvit post_process_object_detection + add text_labels

* Fix copies in grounding dino

* Sync with Owlv2 postprocessing

* Add post_process_grounded_object_detection method to processor, deprecate post_process_object_detection

* Add test cases

* Move text_labels to processors only

* [run-slow] owlvit owlv2

* [run-slow] owlvit, owlv2

* Update snippets

* Update docs structure

* Update deprecated objects for check_repo

* Update docstring for post processing of image guided object detection
This commit is contained in:
Pavel Iakubovskii
2025-01-17 13:58:28 +00:00
committed by GitHub
parent add5f0566c
commit 94ae9a8da1
12 changed files with 467 additions and 188 deletions

View File

@@ -974,8 +974,9 @@ class Owlv2ModelIntegrationTest(unittest.TestCase):
processor = OwlViTProcessor.from_pretrained(model_name)
image = prepare_img()
text_labels = [["a photo of a cat", "a photo of a dog"]]
inputs = processor(
text=[["a photo of a cat", "a photo of a dog"]],
text=text_labels,
images=image,
max_length=16,
padding="max_length",
@@ -991,11 +992,31 @@ class Owlv2ModelIntegrationTest(unittest.TestCase):
expected_slice_logits = torch.tensor(
[[-21.413497, -21.612638], [-19.008193, -19.548841], [-20.958896, -21.382694]]
).to(torch_device)
self.assertTrue(torch.allclose(outputs.logits[0, :3, :3], expected_slice_logits, atol=1e-4))
resulted_slice_logits = outputs.logits[0, :3, :3]
max_diff = torch.max(torch.abs(resulted_slice_logits - expected_slice_logits)).item()
self.assertLess(max_diff, 3e-4)
expected_slice_boxes = torch.tensor(
[[0.241309, 0.051896, 0.453267], [0.139474, 0.045701, 0.250660], [0.233022, 0.050479, 0.427671]],
).to(torch_device)
self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4))
resulted_slice_boxes = outputs.pred_boxes[0, :3, :3]
max_diff = torch.max(torch.abs(resulted_slice_boxes - expected_slice_boxes)).item()
self.assertLess(max_diff, 3e-4)
# test post-processing
post_processed_output = processor.post_process_grounded_object_detection(outputs)
self.assertIsNone(post_processed_output[0]["text_labels"])
post_processed_output_with_text_labels = processor.post_process_grounded_object_detection(
outputs, text_labels=text_labels
)
objects_labels = post_processed_output_with_text_labels[0]["labels"].cpu().tolist()
self.assertListEqual(objects_labels, [0, 0])
objects_text_labels = post_processed_output_with_text_labels[0]["text_labels"]
self.assertIsNotNone(objects_text_labels)
self.assertListEqual(objects_text_labels, ["a photo of a cat", "a photo of a cat"])
@slow
def test_inference_one_shot_object_detection(self):