From 95408e995337f14bafa7f86e63d3885327f956c9 Mon Sep 17 00:00:00 2001 From: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Date: Tue, 7 Mar 2023 13:34:04 +0100 Subject: [PATCH] [DETR, YOLOS] Fix device bug (#21974) * Fix integration test * Add test * Add test --- .../models/detr/image_processing_detr.py | 2 +- .../models/yolos/image_processing_yolos.py | 2 +- tests/models/detr/test_modeling_detr.py | 36 +++++++++++++++++++ tests/models/yolos/test_modeling_yolos.py | 15 +++++++- 4 files changed, 52 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/detr/image_processing_detr.py b/src/transformers/models/detr/image_processing_detr.py index 71e80e8c74..6a0ace45b8 100644 --- a/src/transformers/models/detr/image_processing_detr.py +++ b/src/transformers/models/detr/image_processing_detr.py @@ -1563,7 +1563,7 @@ class DetrImageProcessor(BaseImageProcessor): else: img_h, img_w = target_sizes.unbind(1) - scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1) + scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device) boxes = boxes * scale_fct[:, None, :] results = [] diff --git a/src/transformers/models/yolos/image_processing_yolos.py b/src/transformers/models/yolos/image_processing_yolos.py index 6b6de49c68..150051fba6 100644 --- a/src/transformers/models/yolos/image_processing_yolos.py +++ b/src/transformers/models/yolos/image_processing_yolos.py @@ -1232,7 +1232,7 @@ class YolosImageProcessor(BaseImageProcessor): else: img_h, img_w = target_sizes.unbind(1) - scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1) + scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device) boxes = boxes * scale_fct[:, None, :] results = [] diff --git a/tests/models/detr/test_modeling_detr.py b/tests/models/detr/test_modeling_detr.py index 7b032288b5..c8f782e593 100644 --- a/tests/models/detr/test_modeling_detr.py +++ b/tests/models/detr/test_modeling_detr.py @@ -539,6 +539,7 @@ class DetrModelIntegrationTests(unittest.TestCase): with torch.no_grad(): outputs = model(pixel_values, pixel_mask) + # verify outputs expected_shape_logits = torch.Size((1, model.config.num_queries, model.config.num_labels + 1)) self.assertEqual(outputs.logits.shape, expected_shape_logits) expected_slice_logits = torch.tensor( @@ -553,6 +554,19 @@ class DetrModelIntegrationTests(unittest.TestCase): ).to(torch_device) self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4)) + # verify postprocessing + results = feature_extractor.post_process_object_detection( + outputs, threshold=0.3, target_sizes=[image.size[::-1]] + )[0] + expected_scores = torch.tensor([0.9982, 0.9960, 0.9955, 0.9988, 0.9987]).to(torch_device) + expected_labels = [75, 75, 63, 17, 17] + expected_slice_boxes = torch.tensor([40.1633, 70.8115, 175.5471, 117.9841]).to(torch_device) + + self.assertEqual(len(results["scores"]), 5) + self.assertTrue(torch.allclose(results["scores"], expected_scores, atol=1e-4)) + self.assertSequenceEqual(results["labels"].tolist(), expected_labels) + self.assertTrue(torch.allclose(results["boxes"][0, :], expected_slice_boxes)) + def test_inference_panoptic_segmentation_head(self): model = DetrForSegmentation.from_pretrained("facebook/detr-resnet-50-panoptic").to(torch_device) @@ -565,6 +579,7 @@ class DetrModelIntegrationTests(unittest.TestCase): with torch.no_grad(): outputs = model(pixel_values, pixel_mask) + # verify outputs expected_shape_logits = torch.Size((1, model.config.num_queries, model.config.num_labels + 1)) self.assertEqual(outputs.logits.shape, expected_shape_logits) expected_slice_logits = torch.tensor( @@ -585,3 +600,24 @@ class DetrModelIntegrationTests(unittest.TestCase): [[-7.7558, -10.8788, -11.9797], [-11.8881, -16.4329, -17.7451], [-14.7316, -19.7383, -20.3004]] ).to(torch_device) self.assertTrue(torch.allclose(outputs.pred_masks[0, 0, :3, :3], expected_slice_masks, atol=1e-3)) + + # verify postprocessing + results = feature_extractor.post_process_panoptic_segmentation( + outputs, threshold=0.3, target_sizes=[image.size[::-1]] + )[0] + + expected_shape = torch.Size([480, 640]) + expected_slice_segmentation = torch.tensor([[4, 4, 4], [4, 4, 4], [4, 4, 4]], dtype=torch.int32).to( + torch_device + ) + expected_number_of_segments = 5 + expected_first_segment = {"id": 1, "label_id": 17, "was_fused": False, "score": 0.994096} + + number_of_unique_segments = len(torch.unique(results["segmentation"])) + self.assertTrue( + number_of_unique_segments, expected_number_of_segments + 1 + ) # we add 1 for the background class + self.assertTrue(results["segmentation"].shape, expected_shape) + self.assertTrue(torch.allclose(results["segmentation"][:3, :3], expected_slice_segmentation, atol=1e-4)) + self.assertTrue(len(results["segments_info"]), expected_number_of_segments) + self.assertDictEqual(results["segments_info"][0], expected_first_segment) diff --git a/tests/models/yolos/test_modeling_yolos.py b/tests/models/yolos/test_modeling_yolos.py index e85cffc3d0..209849394b 100644 --- a/tests/models/yolos/test_modeling_yolos.py +++ b/tests/models/yolos/test_modeling_yolos.py @@ -360,7 +360,7 @@ class YolosModelIntegrationTest(unittest.TestCase): with torch.no_grad(): outputs = model(inputs.pixel_values) - # verify the logits + # verify outputs expected_shape = torch.Size((1, 100, 92)) self.assertEqual(outputs.logits.shape, expected_shape) @@ -373,3 +373,16 @@ class YolosModelIntegrationTest(unittest.TestCase): ) self.assertTrue(torch.allclose(outputs.logits[0, :3, :3], expected_slice_logits, atol=1e-4)) self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4)) + + # verify postprocessing + results = feature_extractor.post_process_object_detection( + outputs, threshold=0.3, target_sizes=[image.size[::-1]] + )[0] + expected_scores = torch.tensor([0.9994, 0.9790, 0.9964, 0.9972, 0.9861]).to(torch_device) + expected_labels = [75, 75, 17, 63, 17] + expected_slice_boxes = torch.tensor([335.0609, 79.3848, 375.4216, 187.2495]).to(torch_device) + + self.assertEqual(len(results["scores"]), 5) + self.assertTrue(torch.allclose(results["scores"], expected_scores, atol=1e-4)) + self.assertSequenceEqual(results["labels"].tolist(), expected_labels) + self.assertTrue(torch.allclose(results["boxes"][0, :], expected_slice_boxes))