Add post_process_semantic_segmentation method to SegFormer (#19072)
* add post_process_semantic_segmentation method to SegformerFeatureExtractor * add test for semantic segmentation post-processing
This commit is contained in:
@@ -395,3 +395,30 @@ class SegformerModelIntegrationTest(unittest.TestCase):
|
||||
]
|
||||
).to(torch_device)
|
||||
self.assertTrue(torch.allclose(outputs.logits[0, :3, :3, :3], expected_slice, atol=1e-1))
|
||||
|
||||
@slow
|
||||
def test_post_processing_semantic_segmentation(self):
|
||||
# only resize + normalize
|
||||
feature_extractor = SegformerFeatureExtractor(
|
||||
image_scale=(512, 512), keep_ratio=False, align=False, do_random_crop=False
|
||||
)
|
||||
model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512").to(
|
||||
torch_device
|
||||
)
|
||||
|
||||
image = prepare_img()
|
||||
encoded_inputs = feature_extractor(images=image, return_tensors="pt")
|
||||
pixel_values = encoded_inputs.pixel_values.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(pixel_values)
|
||||
|
||||
outputs.logits = outputs.logits.detach().cpu()
|
||||
|
||||
segmentation = feature_extractor.post_process_semantic_segmentation(outputs=outputs, target_sizes=[(500, 300)])
|
||||
expected_shape = torch.Size((500, 300))
|
||||
self.assertEqual(segmentation[0].shape, expected_shape)
|
||||
|
||||
segmentation = feature_extractor.post_process_semantic_segmentation(outputs=outputs)
|
||||
expected_shape = torch.Size((128, 128))
|
||||
self.assertEqual(segmentation[0].shape, expected_shape)
|
||||
|
||||
Reference in New Issue
Block a user