Add post_process_semantic_segmentation method to SegFormer (#19072)

* add post_process_semantic_segmentation method to SegformerFeatureExtractor
* add test for semantic segmentation post-processing
This commit is contained in:
Alara Dirik
2022-09-21 11:40:35 +03:00
committed by GitHub
parent ef6741fe65
commit 9e95706648
3 changed files with 75 additions and 2 deletions

View File

@@ -93,6 +93,7 @@ SegFormer's results on the segmentation datasets like ADE20k, refer to the [pape
[[autodoc]] SegformerFeatureExtractor [[autodoc]] SegformerFeatureExtractor
- __call__ - __call__
- post_process_semantic_segmentation
## SegformerModel ## SegformerModel

View File

@@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
"""Feature extractor class for SegFormer.""" """Feature extractor class for SegFormer."""
from typing import Optional, Union from typing import List, Optional, Tuple, Union
import numpy as np import numpy as np
from PIL import Image from PIL import Image
@@ -27,9 +27,12 @@ from ...image_utils import (
ImageInput, ImageInput,
is_torch_tensor, is_torch_tensor,
) )
from ...utils import TensorType, logging from ...utils import TensorType, is_torch_available, logging
if is_torch_available():
import torch
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@@ -211,3 +214,45 @@ class SegformerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMi
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors) encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
return encoded_inputs return encoded_inputs
def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None):
"""
Converts the output of [`SegformerForSemanticSegmentation`] into semantic segmentation maps. Only supports
PyTorch.
Args:
outputs ([`SegformerForSemanticSegmentation`]):
Raw outputs of the model.
target_sizes (`List[Tuple]` of length `batch_size`, *optional*):
List of tuples corresponding to the requested final size (height, width) of each prediction. If left to
None, predictions will not be resized.
Returns:
semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic
segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
specified). Each entry of each `torch.Tensor` correspond to a semantic class id.
"""
logits = outputs.logits
# Resize logits and compute semantic segmentation maps
if target_sizes is not None:
if len(logits) != len(target_sizes):
raise ValueError(
"Make sure that you pass in as many target sizes as the batch dimension of the logits"
)
if is_torch_tensor(target_sizes):
target_sizes = target_sizes.numpy()
semantic_segmentation = []
for idx in range(len(logits)):
resized_logits = torch.nn.functional.interpolate(
logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
)
semantic_map = resized_logits[0].argmax(dim=0)
semantic_segmentation.append(semantic_map)
else:
semantic_segmentation = logits.argmax(dim=1)
semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
return semantic_segmentation

View File

@@ -395,3 +395,30 @@ class SegformerModelIntegrationTest(unittest.TestCase):
] ]
).to(torch_device) ).to(torch_device)
self.assertTrue(torch.allclose(outputs.logits[0, :3, :3, :3], expected_slice, atol=1e-1)) self.assertTrue(torch.allclose(outputs.logits[0, :3, :3, :3], expected_slice, atol=1e-1))
@slow
def test_post_processing_semantic_segmentation(self):
# only resize + normalize
feature_extractor = SegformerFeatureExtractor(
image_scale=(512, 512), keep_ratio=False, align=False, do_random_crop=False
)
model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512").to(
torch_device
)
image = prepare_img()
encoded_inputs = feature_extractor(images=image, return_tensors="pt")
pixel_values = encoded_inputs.pixel_values.to(torch_device)
with torch.no_grad():
outputs = model(pixel_values)
outputs.logits = outputs.logits.detach().cpu()
segmentation = feature_extractor.post_process_semantic_segmentation(outputs=outputs, target_sizes=[(500, 300)])
expected_shape = torch.Size((500, 300))
self.assertEqual(segmentation[0].shape, expected_shape)
segmentation = feature_extractor.post_process_semantic_segmentation(outputs=outputs)
expected_shape = torch.Size((128, 128))
self.assertEqual(segmentation[0].shape, expected_shape)