Add post_process_semantic_segmentation method to DPTFeatureExtractor (#19107)

* add post-processing method for semantic segmentation

* add test for post-processing
This commit is contained in:
Alara Dirik
2022-09-21 15:15:26 +03:00
committed by GitHub
parent da6a1b6ca1
commit e7fdfc720a
3 changed files with 68 additions and 3 deletions

View File

@@ -298,3 +298,24 @@ class DPTModelIntegrationTest(unittest.TestCase):
).to(torch_device)
self.assertTrue(torch.allclose(outputs.logits[0, 0, :3, :3], expected_slice, atol=1e-4))
def test_post_processing_semantic_segmentation(self):
feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large-ade")
model = DPTForSemanticSegmentation.from_pretrained("Intel/dpt-large-ade").to(torch_device)
image = prepare_img()
inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
# forward pass
with torch.no_grad():
outputs = model(**inputs)
outputs.logits = outputs.logits.detach().cpu()
segmentation = feature_extractor.post_process_semantic_segmentation(outputs=outputs, target_sizes=[(500, 300)])
expected_shape = torch.Size((500, 300))
self.assertEqual(segmentation[0].shape, expected_shape)
segmentation = feature_extractor.post_process_semantic_segmentation(outputs=outputs)
expected_shape = torch.Size((480, 480))
self.assertEqual(segmentation[0].shape, expected_shape)