[Conditional, Deformable DETR] Add postprocessing methods (#19709)
* Add postprocessing methods * Update docs * Add fix * Add test * Add test for deformable detr postprocessing * Add post processing methods for segmentation * Update code examples * Add post_process to make the pipeline work * Apply updates Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
@@ -303,7 +303,6 @@ class ConditionalDetrFeatureExtractionTest(FeatureExtractionSavingTestMixin, uni
|
||||
masks_path = pathlib.Path("./tests/fixtures/tests_samples/COCO/coco_panoptic")
|
||||
|
||||
# encode them
|
||||
# TODO replace by .from_pretrained microsoft/conditional-detr-resnet-50-panoptic
|
||||
feature_extractor = ConditionalDetrFeatureExtractor(format="coco_panoptic")
|
||||
encoding = feature_extractor(images=image, annotations=target, masks_path=masks_path, return_tensors="pt")
|
||||
|
||||
|
||||
@@ -492,6 +492,7 @@ class ConditionalDetrModelIntegrationTests(unittest.TestCase):
|
||||
with torch.no_grad():
|
||||
outputs = model(pixel_values, pixel_mask)
|
||||
|
||||
# verify logits + box predictions
|
||||
expected_shape_logits = torch.Size((1, model.config.num_queries, model.config.num_labels))
|
||||
self.assertEqual(outputs.logits.shape, expected_shape_logits)
|
||||
expected_slice_logits = torch.tensor(
|
||||
@@ -505,3 +506,16 @@ class ConditionalDetrModelIntegrationTests(unittest.TestCase):
|
||||
[[0.7733, 0.6576, 0.4496], [0.5171, 0.1184, 0.9094], [0.8846, 0.5647, 0.2486]]
|
||||
).to(torch_device)
|
||||
self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4))
|
||||
|
||||
# verify postprocessing
|
||||
results = feature_extractor.post_process_object_detection(
|
||||
outputs, threshold=0.3, target_sizes=[image.size[::-1]]
|
||||
)[0]
|
||||
expected_scores = torch.tensor([0.8330, 0.8313, 0.8039, 0.6829, 0.5355])
|
||||
expected_labels = [75, 17, 17, 75, 63]
|
||||
expected_slice_boxes = torch.tensor([38.3089, 72.1022, 177.6293, 118.4512])
|
||||
|
||||
self.assertEqual(len(results["scores"]), 5)
|
||||
self.assertTrue(torch.allclose(results["scores"], expected_scores, atol=1e-4))
|
||||
self.assertSequenceEqual(results["labels"].tolist(), expected_labels)
|
||||
self.assertTrue(torch.allclose(results["boxes"][0, :], expected_slice_boxes))
|
||||
|
||||
@@ -565,6 +565,19 @@ class DeformableDetrModelIntegrationTests(unittest.TestCase):
|
||||
self.assertEqual(outputs.pred_boxes.shape, expected_shape_boxes)
|
||||
self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_boxes, atol=1e-4))
|
||||
|
||||
# verify postprocessing
|
||||
results = feature_extractor.post_process_object_detection(
|
||||
outputs, threshold=0.3, target_sizes=[image.size[::-1]]
|
||||
)[0]
|
||||
expected_scores = torch.tensor([0.7999, 0.7894, 0.6331, 0.4720, 0.4382])
|
||||
expected_labels = [17, 17, 75, 75, 63]
|
||||
expected_slice_boxes = torch.tensor([16.5028, 52.8390, 318.2544, 470.7841])
|
||||
|
||||
self.assertEqual(len(results["scores"]), 5)
|
||||
self.assertTrue(torch.allclose(results["scores"], expected_scores, atol=1e-4))
|
||||
self.assertSequenceEqual(results["labels"].tolist(), expected_labels)
|
||||
self.assertTrue(torch.allclose(results["boxes"][0, :], expected_slice_boxes))
|
||||
|
||||
def test_inference_object_detection_head_with_box_refine_two_stage(self):
|
||||
model = DeformableDetrForObjectDetection.from_pretrained(
|
||||
"SenseTime/deformable-detr-with-box-refine-two-stage"
|
||||
|
||||
Reference in New Issue
Block a user