Grounding DINO Processor standardization (#34853)

* Add input ids to model output

* Add text preprocessing for processor

* Fix snippet

* Add test for equivalence

* Add type checking guard

* Fixing typehint

* Fix test for added `input_ids` in output

* Add deprecations and "text_labels" to output

* Adjust tests

* Fix test

* Update code examples

* Minor docs and code improvement

* Remove one-liner functions and rename class to CamelCase

* Update docstring

* Fixup
This commit is contained in:
Pavel Iakubovskii
2025-01-17 14:18:16 +00:00
committed by GitHub
parent 42b2857b01
commit 099d93d2e9
5 changed files with 217 additions and 80 deletions

View File

@@ -322,9 +322,9 @@ class GroundingDinoModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Tes
# loss is at first position
if "labels" in inputs_dict:
correct_outlen += 1 # loss is added to beginning
# Object Detection model returns pred_logits and pred_boxes
# Object Detection model returns pred_logits and pred_boxes and input_ids
if model_class.__name__ == "GroundingDinoForObjectDetection":
correct_outlen += 2
correct_outlen += 3
self.assertEqual(out_len, correct_outlen)
@@ -653,7 +653,7 @@ class GroundingDinoModelIntegrationTests(unittest.TestCase):
# verify postprocessing
results = processor.image_processor.post_process_object_detection(
outputs, threshold=0.35, target_sizes=[image.size[::-1]]
outputs, threshold=0.35, target_sizes=[(image.height, image.width)]
)[0]
expected_scores = torch.tensor([0.4526, 0.4082]).to(torch_device)
expected_slice_boxes = torch.tensor([344.8143, 23.1796, 637.4004, 373.8295]).to(torch_device)
@@ -667,14 +667,14 @@ class GroundingDinoModelIntegrationTests(unittest.TestCase):
results = processor.post_process_grounded_object_detection(
outputs=outputs,
input_ids=encoding.input_ids,
box_threshold=0.35,
threshold=0.35,
text_threshold=0.3,
target_sizes=[image.size[::-1]],
target_sizes=[(image.height, image.width)],
)[0]
self.assertTrue(torch.allclose(results["scores"], expected_scores, atol=1e-3))
self.assertTrue(torch.allclose(results["boxes"][0, :], expected_slice_boxes, atol=1e-2))
self.assertListEqual(results["labels"], expected_labels)
self.assertListEqual(results["text_labels"], expected_labels)
@require_torch_accelerator
def test_inference_object_detection_head_equivalence_cpu_gpu(self):
@@ -706,11 +706,11 @@ class GroundingDinoModelIntegrationTests(unittest.TestCase):
# assert postprocessing
results_cpu = processor.image_processor.post_process_object_detection(
cpu_outputs, threshold=0.35, target_sizes=[image.size[::-1]]
cpu_outputs, threshold=0.35, target_sizes=[(image.height, image.width)]
)[0]
result_gpu = processor.image_processor.post_process_object_detection(
gpu_outputs, threshold=0.35, target_sizes=[image.size[::-1]]
gpu_outputs, threshold=0.35, target_sizes=[(image.height, image.width)]
)[0]
self.assertTrue(torch.allclose(results_cpu["scores"], result_gpu["scores"].cpu(), atol=1e-3))

View File

@@ -17,6 +17,7 @@ import os
import shutil
import tempfile
import unittest
from typing import Optional
import pytest
@@ -77,6 +78,20 @@ class GroundingDinoProcessorTest(ProcessorTesterMixin, unittest.TestCase):
self.embed_dim = 5
self.seq_length = 5
def prepare_text_inputs(self, batch_size: Optional[int] = None):
labels = ["a cat", "remote control"]
labels_longer = ["a person", "a car", "a dog", "a cat"]
if batch_size is None:
return labels
if batch_size < 1:
raise ValueError("batch_size must be greater than 0")
if batch_size == 1:
return [labels]
return [labels, labels_longer] + [labels] * (batch_size - 2)
# Copied from tests.models.clip.test_processor_clip.CLIPProcessorTest.get_tokenizer with CLIP->Bert
def get_tokenizer(self, **kwargs):
return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs)
@@ -98,6 +113,7 @@ class GroundingDinoProcessorTest(ProcessorTesterMixin, unittest.TestCase):
return GroundingDinoObjectDetectionOutput(
pred_boxes=torch.rand(self.batch_size, self.num_queries, 4),
logits=torch.rand(self.batch_size, self.num_queries, self.embed_dim),
input_ids=self.get_fake_grounding_dino_input_ids(),
)
def get_fake_grounding_dino_input_ids(self):
@@ -111,14 +127,11 @@ class GroundingDinoProcessorTest(ProcessorTesterMixin, unittest.TestCase):
processor = GroundingDinoProcessor(tokenizer=tokenizer, image_processor=image_processor)
grounding_dino_output = self.get_fake_grounding_dino_output()
grounding_dino_input_ids = self.get_fake_grounding_dino_input_ids()
post_processed = processor.post_process_grounded_object_detection(
grounding_dino_output, grounding_dino_input_ids
)
post_processed = processor.post_process_grounded_object_detection(grounding_dino_output)
self.assertEqual(len(post_processed), self.batch_size)
self.assertEqual(list(post_processed[0].keys()), ["scores", "labels", "boxes"])
self.assertEqual(list(post_processed[0].keys()), ["scores", "boxes", "text_labels", "labels"])
self.assertEqual(post_processed[0]["boxes"].shape, (self.num_queries, 4))
self.assertEqual(post_processed[0]["scores"].shape, (self.num_queries,))
@@ -248,3 +261,26 @@ class GroundingDinoProcessorTest(ProcessorTesterMixin, unittest.TestCase):
inputs = processor(text=input_str, images=image_input)
self.assertListEqual(list(inputs.keys()), processor.model_input_names)
def test_text_preprocessing_equivalence(self):
processor = GroundingDinoProcessor.from_pretrained(self.tmpdirname)
# check for single input
formatted_labels = "a cat. a remote control."
labels = ["a cat", "a remote control"]
inputs1 = processor(text=formatted_labels, return_tensors="pt")
inputs2 = processor(text=labels, return_tensors="pt")
self.assertTrue(
torch.allclose(inputs1["input_ids"], inputs2["input_ids"]),
f"Input ids are not equal for single input: {inputs1['input_ids']} != {inputs2['input_ids']}",
)
# check for batched input
formatted_labels = ["a cat. a remote control.", "a car. a person."]
labels = [["a cat", "a remote control"], ["a car", "a person"]]
inputs1 = processor(text=formatted_labels, return_tensors="pt", padding=True)
inputs2 = processor(text=labels, return_tensors="pt", padding=True)
self.assertTrue(
torch.allclose(inputs1["input_ids"], inputs2["input_ids"]),
f"Input ids are not equal for batched input: {inputs1['input_ids']} != {inputs2['input_ids']}",
)