Fix BeitFeatureExtractor postprocessing (#19119)
* return post-processed segmentations as list, add test * use torch to resize logits * fix assertion error if no target_size is specified
This commit is contained in:
@@ -226,43 +226,43 @@ class BeitFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
|
|||||||
|
|
||||||
return encoded_inputs
|
return encoded_inputs
|
||||||
|
|
||||||
def post_process_semantic_segmentation(self, outputs, target_sizes: Union[TensorType, List[Tuple]] = None):
|
def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None):
|
||||||
"""
|
"""
|
||||||
Converts the output of [`BeitForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.
|
Converts the output of [`BeitForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
outputs ([`BeitForSemanticSegmentation`]):
|
outputs ([`BeitForSemanticSegmentation`]):
|
||||||
Raw outputs of the model.
|
Raw outputs of the model.
|
||||||
target_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `List[Tuple]` of length `batch_size`, *optional*):
|
target_sizes (`List[Tuple]` of length `batch_size`, *optional*):
|
||||||
Torch Tensor (or list) corresponding to the requested final size (h, w) of each prediction. If left to
|
List of tuples corresponding to the requested final size (height, width) of each prediction. If left to
|
||||||
None, predictions will not be resized.
|
None, predictions will not be resized.
|
||||||
Returns:
|
Returns:
|
||||||
semantic_segmentation: `torch.Tensor` of shape `(batch_size, 2)` or `List[torch.Tensor]` of length
|
semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic
|
||||||
`batch_size`, where each item is a semantic segmentation map of of the corresponding target_sizes entry (if
|
segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
|
||||||
`target_sizes` is specified). Each entry of each `torch.Tensor` correspond to a semantic class id.
|
specified). Each entry of each `torch.Tensor` correspond to a semantic class id.
|
||||||
"""
|
"""
|
||||||
logits = outputs.logits
|
logits = outputs.logits
|
||||||
|
|
||||||
if len(logits) != len(target_sizes):
|
# Resize logits and compute semantic segmentation maps
|
||||||
raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")
|
|
||||||
|
|
||||||
if target_sizes is not None and target_sizes.shape[1] != 2:
|
|
||||||
raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")
|
|
||||||
|
|
||||||
semantic_segmentation = logits.argmax(dim=1)
|
|
||||||
|
|
||||||
# Resize semantic segmentation maps
|
|
||||||
if target_sizes is not None:
|
if target_sizes is not None:
|
||||||
|
if len(logits) != len(target_sizes):
|
||||||
|
raise ValueError(
|
||||||
|
"Make sure that you pass in as many target sizes as the batch dimension of the logits"
|
||||||
|
)
|
||||||
|
|
||||||
if is_torch_tensor(target_sizes):
|
if is_torch_tensor(target_sizes):
|
||||||
target_sizes = target_sizes.numpy()
|
target_sizes = target_sizes.numpy()
|
||||||
|
|
||||||
resized_maps = []
|
semantic_segmentation = []
|
||||||
semantic_segmentation = semantic_segmentation.numpy()
|
|
||||||
|
|
||||||
for idx in range(len(semantic_segmentation)):
|
for idx in range(len(logits)):
|
||||||
resized = self.resize(image=semantic_segmentation[idx], size=target_sizes[idx])
|
resized_logits = torch.nn.functional.interpolate(
|
||||||
resized_maps.append(resized)
|
logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
|
||||||
|
)
|
||||||
semantic_segmentation = [torch.Tensor(np.array(image)) for image in resized_maps]
|
semantic_map = resized_logits[0].argmax(dim=0)
|
||||||
|
semantic_segmentation.append(semantic_map)
|
||||||
|
else:
|
||||||
|
semantic_segmentation = logits.argmax(dim=1)
|
||||||
|
semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
|
||||||
|
|
||||||
return semantic_segmentation
|
return semantic_segmentation
|
||||||
|
|||||||
@@ -455,3 +455,28 @@ class BeitModelIntegrationTest(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.assertTrue(torch.allclose(logits[0, :3, :3, :3], expected_slice, atol=1e-4))
|
self.assertTrue(torch.allclose(logits[0, :3, :3, :3], expected_slice, atol=1e-4))
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_post_processing_semantic_segmentation(self):
|
||||||
|
model = BeitForSemanticSegmentation.from_pretrained("microsoft/beit-base-finetuned-ade-640-640")
|
||||||
|
model = model.to(torch_device)
|
||||||
|
|
||||||
|
feature_extractor = BeitFeatureExtractor(do_resize=True, size=640, do_center_crop=False)
|
||||||
|
|
||||||
|
ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
|
||||||
|
image = Image.open(ds[0]["file"])
|
||||||
|
inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
|
||||||
|
|
||||||
|
# forward pass
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(**inputs)
|
||||||
|
|
||||||
|
outputs.logits = outputs.logits.detach().cpu()
|
||||||
|
|
||||||
|
segmentation = feature_extractor.post_process_semantic_segmentation(outputs=outputs, target_sizes=[(500, 300)])
|
||||||
|
expected_shape = torch.Size((500, 300))
|
||||||
|
self.assertEqual(segmentation[0].shape, expected_shape)
|
||||||
|
|
||||||
|
segmentation = feature_extractor.post_process_semantic_segmentation(outputs=outputs)
|
||||||
|
expected_shape = torch.Size((160, 160))
|
||||||
|
self.assertEqual(segmentation[0].shape, expected_shape)
|
||||||
|
|||||||
Reference in New Issue
Block a user