OmDet Turbo processor standardization (#34937)
* Fix docstring
* Fix docstring
* Add `classes_structure` to model output
* Update omdet postprocessing
* Adjust tests
* Update code example in docs
* Add deprecation to "classes" key in output
* Types, docs
* Fixing test
* Fix missed clip_boxes
* [run-slow] omdet_turbo
* Apply suggestions from code review

Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com>

* Make CamelCase class

---------

Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
94ae9a8da1
commit
42b2857b01
@@ -646,9 +646,9 @@ def prepare_img():
|
||||
|
||||
|
||||
def prepare_text():
|
||||
classes = ["cat", "remote"]
|
||||
task = "Detect {}.".format(", ".join(classes))
|
||||
return classes, task
|
||||
text_labels = ["cat", "remote"]
|
||||
task = "Detect {}.".format(", ".join(text_labels))
|
||||
return text_labels, task
|
||||
|
||||
|
||||
def prepare_img_batched():
|
||||
@@ -660,14 +660,14 @@ def prepare_img_batched():
|
||||
|
||||
|
||||
def prepare_text_batched():
|
||||
classes1 = ["cat", "remote"]
|
||||
classes2 = ["boat"]
|
||||
classes3 = ["statue", "trees", "torch"]
|
||||
text_labels1 = ["cat", "remote"]
|
||||
text_labels2 = ["boat"]
|
||||
text_labels3 = ["statue", "trees", "torch"]
|
||||
|
||||
task1 = "Detect {}.".format(", ".join(classes1))
|
||||
task1 = "Detect {}.".format(", ".join(text_labels1))
|
||||
task2 = "Detect all the boat in the image."
|
||||
task3 = "Focus on the foreground, detect statue, torch and trees."
|
||||
return [classes1, classes2, classes3], [task1, task2, task3]
|
||||
return [text_labels1, text_labels2, text_labels3], [task1, task2, task3]
|
||||
|
||||
|
||||
@require_timm
|
||||
@@ -683,8 +683,8 @@ class OmDetTurboModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
processor = self.default_processor
|
||||
image = prepare_img()
|
||||
classes, task = prepare_text()
|
||||
encoding = processor(images=image, text=classes, task=task, return_tensors="pt").to(torch_device)
|
||||
text_labels, task = prepare_text()
|
||||
encoding = processor(images=image, text=text_labels, task=task, return_tensors="pt").to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**encoding)
|
||||
@@ -706,7 +706,7 @@ class OmDetTurboModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
# verify grounded postprocessing
|
||||
results = processor.post_process_grounded_object_detection(
|
||||
outputs, classes=[classes], target_sizes=[image.size[::-1]]
|
||||
outputs, text_labels=[text_labels], target_sizes=[image.size[::-1]]
|
||||
)[0]
|
||||
expected_scores = torch.tensor([0.7675, 0.7196, 0.5634, 0.5524]).to(torch_device)
|
||||
expected_slice_boxes = torch.tensor([39.8870, 70.3522, 176.7424, 118.0354]).to(torch_device)
|
||||
@@ -715,8 +715,8 @@ class OmDetTurboModelIntegrationTests(unittest.TestCase):
|
||||
self.assertTrue(torch.allclose(results["scores"], expected_scores, atol=1e-2))
|
||||
self.assertTrue(torch.allclose(results["boxes"][0, :], expected_slice_boxes, atol=1e-2))
|
||||
|
||||
expected_classes = ["remote", "cat", "remote", "cat"]
|
||||
self.assertListEqual(results["classes"], expected_classes)
|
||||
expected_text_labels = ["remote", "cat", "remote", "cat"]
|
||||
self.assertListEqual(results["text_labels"], expected_text_labels)
|
||||
|
||||
def test_inference_object_detection_head_fp16(self):
|
||||
model = OmDetTurboForObjectDetection.from_pretrained("omlab/omdet-turbo-swin-tiny-hf").to(
|
||||
@@ -725,8 +725,8 @@ class OmDetTurboModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
processor = self.default_processor
|
||||
image = prepare_img()
|
||||
classes, task = prepare_text()
|
||||
encoding = processor(images=image, text=classes, task=task, return_tensors="pt").to(
|
||||
text_labels, task = prepare_text()
|
||||
encoding = processor(images=image, text=text_labels, task=task, return_tensors="pt").to(
|
||||
torch_device, dtype=torch.float16
|
||||
)
|
||||
|
||||
@@ -750,7 +750,7 @@ class OmDetTurboModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
# verify grounded postprocessing
|
||||
results = processor.post_process_grounded_object_detection(
|
||||
outputs, classes=[classes], target_sizes=[image.size[::-1]]
|
||||
outputs, text_labels=[text_labels], target_sizes=[image.size[::-1]]
|
||||
)[0]
|
||||
expected_scores = torch.tensor([0.7675, 0.7196, 0.5634, 0.5524]).to(torch_device, dtype=torch.float16)
|
||||
expected_slice_boxes = torch.tensor([39.8870, 70.3522, 176.7424, 118.0354]).to(
|
||||
@@ -761,16 +761,16 @@ class OmDetTurboModelIntegrationTests(unittest.TestCase):
|
||||
self.assertTrue(torch.allclose(results["scores"], expected_scores, atol=1e-2))
|
||||
self.assertTrue(torch.allclose(results["boxes"][0, :], expected_slice_boxes, atol=1e-1))
|
||||
|
||||
expected_classes = ["remote", "cat", "remote", "cat"]
|
||||
self.assertListEqual(results["classes"], expected_classes)
|
||||
expected_text_labels = ["remote", "cat", "remote", "cat"]
|
||||
self.assertListEqual(results["text_labels"], expected_text_labels)
|
||||
|
||||
def test_inference_object_detection_head_no_task(self):
|
||||
model = OmDetTurboForObjectDetection.from_pretrained("omlab/omdet-turbo-swin-tiny-hf").to(torch_device)
|
||||
|
||||
processor = self.default_processor
|
||||
image = prepare_img()
|
||||
classes, _ = prepare_text()
|
||||
encoding = processor(images=image, text=classes, return_tensors="pt").to(torch_device)
|
||||
text_labels, _ = prepare_text()
|
||||
encoding = processor(images=image, text=text_labels, return_tensors="pt").to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**encoding)
|
||||
@@ -792,7 +792,7 @@ class OmDetTurboModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
# verify grounded postprocessing
|
||||
results = processor.post_process_grounded_object_detection(
|
||||
outputs, classes=[classes], target_sizes=[image.size[::-1]]
|
||||
outputs, text_labels=[text_labels], target_sizes=[image.size[::-1]]
|
||||
)[0]
|
||||
expected_scores = torch.tensor([0.7675, 0.7196, 0.5634, 0.5524]).to(torch_device)
|
||||
expected_slice_boxes = torch.tensor([39.8870, 70.3522, 176.7424, 118.0354]).to(torch_device)
|
||||
@@ -801,8 +801,8 @@ class OmDetTurboModelIntegrationTests(unittest.TestCase):
|
||||
self.assertTrue(torch.allclose(results["scores"], expected_scores, atol=1e-2))
|
||||
self.assertTrue(torch.allclose(results["boxes"][0, :], expected_slice_boxes, atol=1e-2))
|
||||
|
||||
expected_classes = ["remote", "cat", "remote", "cat"]
|
||||
self.assertListEqual(results["classes"], expected_classes)
|
||||
expected_text_labels = ["remote", "cat", "remote", "cat"]
|
||||
self.assertListEqual(results["text_labels"], expected_text_labels)
|
||||
|
||||
def test_inference_object_detection_head_batched(self):
|
||||
torch_device = "cpu"
|
||||
@@ -810,10 +810,10 @@ class OmDetTurboModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
processor = self.default_processor
|
||||
images_batched = prepare_img_batched()
|
||||
classes_batched, tasks_batched = prepare_text_batched()
|
||||
encoding = processor(images=images_batched, text=classes_batched, task=tasks_batched, return_tensors="pt").to(
|
||||
torch_device
|
||||
)
|
||||
text_labels_batched, tasks_batched = prepare_text_batched()
|
||||
encoding = processor(
|
||||
images=images_batched, text=text_labels_batched, task=tasks_batched, return_tensors="pt"
|
||||
).to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**encoding)
|
||||
@@ -837,7 +837,7 @@ class OmDetTurboModelIntegrationTests(unittest.TestCase):
|
||||
# verify grounded postprocessing
|
||||
results = processor.post_process_grounded_object_detection(
|
||||
outputs,
|
||||
classes=classes_batched,
|
||||
text_labels=text_labels_batched,
|
||||
target_sizes=[image.size[::-1] for image in images_batched],
|
||||
score_threshold=0.2,
|
||||
)
|
||||
@@ -858,19 +858,19 @@ class OmDetTurboModelIntegrationTests(unittest.TestCase):
|
||||
torch.allclose(torch.stack([result["boxes"][0, :] for result in results]), expected_slice_boxes, atol=1e-2)
|
||||
)
|
||||
|
||||
expected_classes = [
|
||||
expected_text_labels = [
|
||||
["remote", "cat", "remote", "cat"],
|
||||
["boat", "boat", "boat", "boat"],
|
||||
["statue", "trees", "trees", "torch", "statue", "statue"],
|
||||
]
|
||||
self.assertListEqual([result["classes"] for result in results], expected_classes)
|
||||
self.assertListEqual([result["text_labels"] for result in results], expected_text_labels)
|
||||
|
||||
@require_torch_accelerator
|
||||
def test_inference_object_detection_head_equivalence_cpu_gpu(self):
|
||||
processor = self.default_processor
|
||||
image = prepare_img()
|
||||
classes, task = prepare_text()
|
||||
encoding = processor(images=image, text=classes, task=task, return_tensors="pt")
|
||||
text_labels, task = prepare_text()
|
||||
encoding = processor(images=image, text=text_labels, task=task, return_tensors="pt")
|
||||
# 1. run model on CPU
|
||||
model = OmDetTurboForObjectDetection.from_pretrained("omlab/omdet-turbo-swin-tiny-hf")
|
||||
|
||||
@@ -894,10 +894,10 @@ class OmDetTurboModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
# verify grounded postprocessing
|
||||
results_cpu = processor.post_process_grounded_object_detection(
|
||||
cpu_outputs, classes=[classes], target_sizes=[image.size[::-1]]
|
||||
cpu_outputs, text_labels=[text_labels], target_sizes=[image.size[::-1]]
|
||||
)[0]
|
||||
result_gpu = processor.post_process_grounded_object_detection(
|
||||
gpu_outputs, classes=[classes], target_sizes=[image.size[::-1]]
|
||||
gpu_outputs, text_labels=[text_labels], target_sizes=[image.size[::-1]]
|
||||
)[0]
|
||||
|
||||
self.assertTrue(torch.allclose(results_cpu["scores"], result_gpu["scores"].cpu(), atol=1e-2))
|
||||
|
||||
@@ -76,10 +76,13 @@ class OmDetTurboProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
shutil.rmtree(self.tmpdirname)
|
||||
|
||||
def get_fake_omdet_turbo_output(self):
|
||||
classes = self.get_fake_omdet_turbo_classes()
|
||||
classes_structure = torch.tensor([len(sublist) for sublist in classes])
|
||||
torch.manual_seed(42)
|
||||
return OmDetTurboObjectDetectionOutput(
|
||||
decoder_coord_logits=torch.rand(self.batch_size, self.num_queries, 4),
|
||||
decoder_class_logits=torch.rand(self.batch_size, self.num_queries, self.embed_dim),
|
||||
classes_structure=classes_structure,
|
||||
)
|
||||
|
||||
def get_fake_omdet_turbo_classes(self):
|
||||
@@ -99,7 +102,7 @@ class OmDetTurboProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
)
|
||||
|
||||
self.assertEqual(len(post_processed), self.batch_size)
|
||||
self.assertEqual(list(post_processed[0].keys()), ["boxes", "scores", "classes"])
|
||||
self.assertEqual(list(post_processed[0].keys()), ["boxes", "scores", "labels", "text_labels"])
|
||||
self.assertEqual(post_processed[0]["boxes"].shape, (self.num_queries, 4))
|
||||
self.assertEqual(post_processed[0]["scores"].shape, (self.num_queries,))
|
||||
expected_scores = torch.tensor([0.7310, 0.6579, 0.6513, 0.6444, 0.6252])
|
||||
|
||||
Reference in New Issue
Block a user