Add semantic segmentation post-processing method to MobileViT (#19105)
* add post-processing method for semantic segmentation Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -66,6 +66,7 @@ This model was contributed by [matthijs](https://huggingface.co/Matthijs). The T
|
|||||||
|
|
||||||
[[autodoc]] MobileViTFeatureExtractor
|
[[autodoc]] MobileViTFeatureExtractor
|
||||||
- __call__
|
- __call__
|
||||||
|
- post_process_semantic_segmentation
|
||||||
|
|
||||||
## MobileViTModel
|
## MobileViTModel
|
||||||
|
|
||||||
|
|||||||
@@ -14,16 +14,19 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Feature extractor class for MobileViT."""
|
"""Feature extractor class for MobileViT."""
|
||||||
|
|
||||||
from typing import Optional, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
|
from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
|
||||||
from ...image_utils import ImageFeatureExtractionMixin, ImageInput, is_torch_tensor
|
from ...image_utils import ImageFeatureExtractionMixin, ImageInput, is_torch_tensor
|
||||||
from ...utils import TensorType, logging
|
from ...utils import TensorType, is_torch_available, logging
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -151,3 +154,46 @@ class MobileViTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMi
|
|||||||
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
|
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
|
||||||
|
|
||||||
return encoded_inputs
|
return encoded_inputs
|
||||||
|
|
||||||
|
def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None):
|
||||||
|
"""
|
||||||
|
Converts the output of [`MobileViTForSemanticSegmentation`] into semantic segmentation maps. Only supports
|
||||||
|
PyTorch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
outputs ([`MobileViTForSemanticSegmentation`]):
|
||||||
|
Raw outputs of the model.
|
||||||
|
target_sizes (`List[Tuple]`, *optional*):
|
||||||
|
A list of length `batch_size`, where each item is a `Tuple[int, int]` corresponding to the requested
|
||||||
|
final size (height, width) of each prediction. If left to None, predictions will not be resized.
|
||||||
|
Returns:
|
||||||
|
`List[torch.Tensor]`:
|
||||||
|
A list of length `batch_size`, where each item is a semantic segmentation map of shape (height, width)
|
||||||
|
corresponding to the target_sizes entry (if `target_sizes` is specified). Each entry of each
|
||||||
|
`torch.Tensor` correspond to a semantic class id.
|
||||||
|
"""
|
||||||
|
logits = outputs.logits
|
||||||
|
|
||||||
|
# Resize logits and compute semantic segmentation maps
|
||||||
|
if target_sizes is not None:
|
||||||
|
if len(logits) != len(target_sizes):
|
||||||
|
raise ValueError(
|
||||||
|
"Make sure that you pass in as many target sizes as the batch dimension of the logits"
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_torch_tensor(target_sizes):
|
||||||
|
target_sizes = target_sizes.numpy()
|
||||||
|
|
||||||
|
semantic_segmentation = []
|
||||||
|
|
||||||
|
for idx in range(len(logits)):
|
||||||
|
resized_logits = torch.nn.functional.interpolate(
|
||||||
|
logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
|
||||||
|
)
|
||||||
|
semantic_map = resized_logits[0].argmax(dim=0)
|
||||||
|
semantic_segmentation.append(semantic_map)
|
||||||
|
else:
|
||||||
|
semantic_segmentation = logits.argmax(dim=1)
|
||||||
|
semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
|
||||||
|
|
||||||
|
return semantic_segmentation
|
||||||
|
|||||||
@@ -340,3 +340,27 @@ class MobileViTModelIntegrationTest(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.assertTrue(torch.allclose(logits[0, :3, :3, :3], expected_slice, atol=1e-4))
|
self.assertTrue(torch.allclose(logits[0, :3, :3, :3], expected_slice, atol=1e-4))
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_post_processing_semantic_segmentation(self):
|
||||||
|
model = MobileViTForSemanticSegmentation.from_pretrained("apple/deeplabv3-mobilevit-xx-small")
|
||||||
|
model = model.to(torch_device)
|
||||||
|
|
||||||
|
feature_extractor = MobileViTFeatureExtractor.from_pretrained("apple/deeplabv3-mobilevit-xx-small")
|
||||||
|
|
||||||
|
image = prepare_img()
|
||||||
|
inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
|
||||||
|
|
||||||
|
# forward pass
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(**inputs)
|
||||||
|
|
||||||
|
outputs.logits = outputs.logits.detach().cpu()
|
||||||
|
|
||||||
|
segmentation = feature_extractor.post_process_semantic_segmentation(outputs=outputs, target_sizes=[(50, 60)])
|
||||||
|
expected_shape = torch.Size((50, 60))
|
||||||
|
self.assertEqual(segmentation[0].shape, expected_shape)
|
||||||
|
|
||||||
|
segmentation = feature_extractor.post_process_semantic_segmentation(outputs=outputs)
|
||||||
|
expected_shape = torch.Size((32, 32))
|
||||||
|
self.assertEqual(segmentation[0].shape, expected_shape)
|
||||||
|
|||||||
Reference in New Issue
Block a user