OwlViT/Owlv2 post processing standardization (#34929)
* Refactor owlvit post_process_object_detection + add text_labels
* Fix copies in grounding dino
* Sync with Owlv2 postprocessing
* Add post_process_grounded_object_detection method to processor, deprecate post_process_object_detection
* Add test cases
* Move text_labels to processors only
* [run-slow] owlvit owlv2
* [run-slow] owlvit, owlv2
* Update snippets
* Update docs structure
* Update deprecated objects for check_repo
* Update docstring for post processing of image guided object detection
This commit is contained in:
committed by
GitHub
parent
add5f0566c
commit
94ae9a8da1
@@ -974,8 +974,9 @@ class Owlv2ModelIntegrationTest(unittest.TestCase):
|
||||
processor = OwlViTProcessor.from_pretrained(model_name)
|
||||
|
||||
image = prepare_img()
|
||||
text_labels = [["a photo of a cat", "a photo of a dog"]]
|
||||
inputs = processor(
|
||||
text=[["a photo of a cat", "a photo of a dog"]],
|
||||
text=text_labels,
|
||||
images=image,
|
||||
max_length=16,
|
||||
padding="max_length",
|
||||
@@ -991,11 +992,31 @@ class Owlv2ModelIntegrationTest(unittest.TestCase):
|
||||
expected_slice_logits = torch.tensor(
|
||||
[[-21.413497, -21.612638], [-19.008193, -19.548841], [-20.958896, -21.382694]]
|
||||
).to(torch_device)
|
||||
self.assertTrue(torch.allclose(outputs.logits[0, :3, :3], expected_slice_logits, atol=1e-4))
|
||||
resulted_slice_logits = outputs.logits[0, :3, :3]
|
||||
max_diff = torch.max(torch.abs(resulted_slice_logits - expected_slice_logits)).item()
|
||||
self.assertLess(max_diff, 3e-4)
|
||||
|
||||
expected_slice_boxes = torch.tensor(
|
||||
[[0.241309, 0.051896, 0.453267], [0.139474, 0.045701, 0.250660], [0.233022, 0.050479, 0.427671]],
|
||||
).to(torch_device)
|
||||
self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4))
|
||||
resulted_slice_boxes = outputs.pred_boxes[0, :3, :3]
|
||||
max_diff = torch.max(torch.abs(resulted_slice_boxes - expected_slice_boxes)).item()
|
||||
self.assertLess(max_diff, 3e-4)
|
||||
|
||||
# test post-processing
|
||||
post_processed_output = processor.post_process_grounded_object_detection(outputs)
|
||||
self.assertIsNone(post_processed_output[0]["text_labels"])
|
||||
|
||||
post_processed_output_with_text_labels = processor.post_process_grounded_object_detection(
|
||||
outputs, text_labels=text_labels
|
||||
)
|
||||
|
||||
objects_labels = post_processed_output_with_text_labels[0]["labels"].cpu().tolist()
|
||||
self.assertListEqual(objects_labels, [0, 0])
|
||||
|
||||
objects_text_labels = post_processed_output_with_text_labels[0]["text_labels"]
|
||||
self.assertIsNotNone(objects_text_labels)
|
||||
self.assertListEqual(objects_text_labels, ["a photo of a cat", "a photo of a cat"])
|
||||
|
||||
@slow
|
||||
def test_inference_one_shot_object_detection(self):
|
||||
|
||||
Reference in New Issue
Block a user