Beit postprocessing (#19099)

* add post_process_semantic_segmentation method to BeiTFeatureExtractor
This commit is contained in:
Alara Dirik
2022-09-20 10:41:56 +03:00
committed by GitHub
parent 261301d388
commit c81ebd1c39
2 changed files with 47 additions and 2 deletions

View File

@@ -82,6 +82,7 @@ contributed by [kamalkraj](https://huggingface.co/kamalkraj). The original code
[[autodoc]] BeitFeatureExtractor [[autodoc]] BeitFeatureExtractor
- __call__ - __call__
- post_process_semantic_segmentation
## BeitModel ## BeitModel

View File

@@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
"""Feature extractor class for BEiT.""" """Feature extractor class for BEiT."""
from typing import Optional, Union from typing import List, Optional, Tuple, Union
import numpy as np import numpy as np
from PIL import Image from PIL import Image
@@ -27,9 +27,12 @@ from ...image_utils import (
ImageInput, ImageInput,
is_torch_tensor, is_torch_tensor,
) )
from ...utils import TensorType, logging from ...utils import TensorType, is_torch_available, logging
if is_torch_available():
import torch
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@@ -222,3 +225,44 @@ class BeitFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors) encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
return encoded_inputs return encoded_inputs
def post_process_semantic_segmentation(self, outputs, target_sizes: Union[TensorType, List[Tuple]] = None):
"""
Converts the output of [`BeitForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.
Args:
outputs ([`BeitForSemanticSegmentation`]):
Raw outputs of the model.
target_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `List[Tuple]` of length `batch_size`, *optional*):
Torch Tensor (or list) corresponding to the requested final size (h, w) of each prediction. If left to
None, predictions will not be resized.
Returns:
semantic_segmentation: `torch.Tensor` of shape `(batch_size, 2)` or `List[torch.Tensor]` of length
`batch_size`, where each item is a semantic segmentation map of of the corresponding target_sizes entry (if
`target_sizes` is specified). Each entry of each `torch.Tensor` correspond to a semantic class id.
"""
logits = outputs.logits
if len(logits) != len(target_sizes):
raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")
if target_sizes is not None and target_sizes.shape[1] != 2:
raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")
semantic_segmentation = logits.argmax(dim=1)
# Resize semantic segmentation maps
if target_sizes is not None:
if is_torch_tensor(target_sizes):
target_sizes = target_sizes.numpy()
resized_maps = []
semantic_segmentation = semantic_segmentation.numpy()
for idx in range(len(semantic_segmentation)):
resized = self.resize(image=semantic_segmentation[idx], size=target_sizes[idx])
resized_maps.append(resized)
semantic_segmentation = [torch.Tensor(np.array(image)) for image in resized_maps]
return semantic_segmentation