Beit postprocessing (#19099)
* add post_process_semantic_segmentation method to BeiTFeatureExtractor
This commit is contained in:
@@ -82,6 +82,7 @@ contributed by [kamalkraj](https://huggingface.co/kamalkraj). The original code
|
|||||||
|
|
||||||
[[autodoc]] BeitFeatureExtractor
|
[[autodoc]] BeitFeatureExtractor
|
||||||
- __call__
|
- __call__
|
||||||
|
- post_process_semantic_segmentation
|
||||||
|
|
||||||
## BeitModel
|
## BeitModel
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Feature extractor class for BEiT."""
|
"""Feature extractor class for BEiT."""
|
||||||
|
|
||||||
from typing import Optional, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
@@ -27,9 +27,12 @@ from ...image_utils import (
|
|||||||
ImageInput,
|
ImageInput,
|
||||||
is_torch_tensor,
|
is_torch_tensor,
|
||||||
)
|
)
|
||||||
from ...utils import TensorType, logging
|
from ...utils import TensorType, is_torch_available, logging
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -222,3 +225,44 @@ class BeitFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
|
|||||||
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
|
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
|
||||||
|
|
||||||
return encoded_inputs
|
return encoded_inputs
|
||||||
|
|
||||||
|
def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[Union[TensorType, List[Tuple]]] = None):
    """
    Converts the output of [`BeitForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.

    Args:
        outputs ([`BeitForSemanticSegmentation`]):
            Raw outputs of the model.
        target_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `List[Tuple]` of length `batch_size`, *optional*):
            Torch Tensor (or list) corresponding to the requested final size (h, w) of each prediction. If left to
            None, predictions will not be resized.

    Returns:
        semantic_segmentation: If `target_sizes` is None, a single `torch.Tensor` holding the argmax of
        `outputs.logits` over dimension 1 (for logits of shape `(batch_size, num_labels, height, width)` this is
        `(batch_size, height, width)`). Otherwise a `List[torch.Tensor]` of length `batch_size`, where each item is
        the semantic segmentation map resized to the corresponding `target_sizes` entry. Each entry of each
        `torch.Tensor` corresponds to a semantic class id.

    Raises:
        ValueError: If the number of target sizes differs from the batch dimension of the logits, or if any
            target size is not a pair `(h, w)`.
    """
    logits = outputs.logits

    # The class id of each position is the index of the largest logit along dimension 1.
    semantic_segmentation = logits.argmax(dim=1)

    # Nothing to validate or resize when no target sizes were requested.
    if target_sizes is None:
        return semantic_segmentation

    if len(logits) != len(target_sizes):
        raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")

    # Move to host memory first so the conversion also works for tensors living on an accelerator.
    if is_torch_tensor(target_sizes):
        target_sizes = target_sizes.cpu().numpy()

    # `np.shape` handles tuples, lists and array rows alike, so both documented input kinds are validated.
    if any(np.shape(size) != (2,) for size in target_sizes):
        raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")

    # Resize semantic segmentation maps
    segmentation_maps = semantic_segmentation.cpu().numpy()
    resized_maps = [
        self.resize(image=segmentation_map, size=size)
        for segmentation_map, size in zip(segmentation_maps, target_sizes)
    ]

    # NOTE(review): `torch.Tensor(...)` yields the default floating dtype although the values are class ids;
    # kept unchanged so callers relying on the current return dtype are not affected.
    return [torch.Tensor(np.array(image)) for image in resized_maps]
|
||||||
|
|||||||
Reference in New Issue
Block a user