Add post_process_semantic_segmentation method to SegFormer (#19072)

* add post_process_semantic_segmentation method to SegformerFeatureExtractor
* add test for semantic segmentation post-processing
This commit is contained in:
Alara Dirik
2022-09-21 11:40:35 +03:00
committed by GitHub
parent ef6741fe65
commit 9e95706648
3 changed files with 75 additions and 2 deletions

View File

@@ -395,3 +395,30 @@ class SegformerModelIntegrationTest(unittest.TestCase):
]
).to(torch_device)
self.assertTrue(torch.allclose(outputs.logits[0, :3, :3, :3], expected_slice, atol=1e-1))
@slow
def test_post_processing_semantic_segmentation(self):
# only resize + normalize
feature_extractor = SegformerFeatureExtractor(
image_scale=(512, 512), keep_ratio=False, align=False, do_random_crop=False
)
model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512").to(
torch_device
)
image = prepare_img()
encoded_inputs = feature_extractor(images=image, return_tensors="pt")
pixel_values = encoded_inputs.pixel_values.to(torch_device)
with torch.no_grad():
outputs = model(pixel_values)
outputs.logits = outputs.logits.detach().cpu()
segmentation = feature_extractor.post_process_semantic_segmentation(outputs=outputs, target_sizes=[(500, 300)])
expected_shape = torch.Size((500, 300))
self.assertEqual(segmentation[0].shape, expected_shape)
segmentation = feature_extractor.post_process_semantic_segmentation(outputs=outputs)
expected_shape = torch.Size((128, 128))
self.assertEqual(segmentation[0].shape, expected_shape)