diff --git a/docs/source/en/model_doc/dpt.mdx b/docs/source/en/model_doc/dpt.mdx index cdf009c6c8..fec1801661 100644 --- a/docs/source/en/model_doc/dpt.mdx +++ b/docs/source/en/model_doc/dpt.mdx @@ -37,6 +37,7 @@ This model was contributed by [nielsr](https://huggingface.co/nielsr). The origi [[autodoc]] DPTFeatureExtractor - __call__ + - post_process_semantic_segmentation ## DPTModel diff --git a/src/transformers/models/dpt/feature_extraction_dpt.py b/src/transformers/models/dpt/feature_extraction_dpt.py index d4346b96f8..8f9f624a9b 100644 --- a/src/transformers/models/dpt/feature_extraction_dpt.py +++ b/src/transformers/models/dpt/feature_extraction_dpt.py @@ -14,13 +14,12 @@ # limitations under the License. """Feature extractor class for DPT.""" -from typing import Optional, Union +from typing import List, Optional, Tuple, Union import numpy as np from PIL import Image from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin -from ...file_utils import TensorType from ...image_utils import ( IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD, @@ -28,9 +27,12 @@ from ...image_utils import ( ImageInput, is_torch_tensor, ) -from ...utils import logging +from ...utils import TensorType, is_torch_available, logging +if is_torch_available(): + import torch + logger = logging.get_logger(__name__) @@ -200,3 +202,44 @@ class DPTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin): encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors) return encoded_inputs + + def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None): + """ + Converts the output of [`DPTForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch. + + Args: + outputs ([`DPTForSemanticSegmentation`]): + Raw outputs of the model. + target_sizes (`List[Tuple]` of length `batch_size`, *optional*): + List of tuples corresponding to the requested final size (height, width) of each prediction. If left to + None, predictions will not be resized. + Returns: + semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic + segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is + specified). Each entry of each `torch.Tensor` correspond to a semantic class id. + """ + logits = outputs.logits + + # Resize logits and compute semantic segmentation maps + if target_sizes is not None: + if len(logits) != len(target_sizes): + raise ValueError( + "Make sure that you pass in as many target sizes as the batch dimension of the logits" + ) + + if is_torch_tensor(target_sizes): + target_sizes = target_sizes.numpy() + + semantic_segmentation = [] + + for idx in range(len(logits)): + resized_logits = torch.nn.functional.interpolate( + logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False + ) + semantic_map = resized_logits[0].argmax(dim=0) + semantic_segmentation.append(semantic_map) + else: + semantic_segmentation = logits.argmax(dim=1) + semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])] + + return semantic_segmentation diff --git a/tests/models/dpt/test_modeling_dpt.py b/tests/models/dpt/test_modeling_dpt.py index 3266ea78a7..ef063f3617 100644 --- a/tests/models/dpt/test_modeling_dpt.py +++ b/tests/models/dpt/test_modeling_dpt.py @@ -298,3 +298,24 @@ class DPTModelIntegrationTest(unittest.TestCase): ).to(torch_device) self.assertTrue(torch.allclose(outputs.logits[0, 0, :3, :3], expected_slice, atol=1e-4)) + + def test_post_processing_semantic_segmentation(self): + feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large-ade") + model = DPTForSemanticSegmentation.from_pretrained("Intel/dpt-large-ade").to(torch_device) + + image = prepare_img() + inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device) + + # forward pass + with torch.no_grad(): + outputs = model(**inputs) + + outputs.logits = outputs.logits.detach().cpu() + + segmentation = feature_extractor.post_process_semantic_segmentation(outputs=outputs, target_sizes=[(500, 300)]) + expected_shape = torch.Size((500, 300)) + self.assertEqual(segmentation[0].shape, expected_shape) + + segmentation = feature_extractor.post_process_semantic_segmentation(outputs=outputs) + expected_shape = torch.Size((480, 480)) + self.assertEqual(segmentation[0].shape, expected_shape)