Add semantic segmentation post-processing method to MobileViT (#19105)

* add post-processing method for semantic segmentation

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Alara Dirik
2022-09-23 16:24:28 +03:00
committed by GitHub
parent 905635f5d3
commit 7e84723fe4
3 changed files with 73 additions and 2 deletions

View File

@@ -340,3 +340,27 @@ class MobileViTModelIntegrationTest(unittest.TestCase):
)
self.assertTrue(torch.allclose(logits[0, :3, :3, :3], expected_slice, atol=1e-4))
@slow
def test_post_processing_semantic_segmentation(self):
model = MobileViTForSemanticSegmentation.from_pretrained("apple/deeplabv3-mobilevit-xx-small")
model = model.to(torch_device)
feature_extractor = MobileViTFeatureExtractor.from_pretrained("apple/deeplabv3-mobilevit-xx-small")
image = prepare_img()
inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
# forward pass
with torch.no_grad():
outputs = model(**inputs)
outputs.logits = outputs.logits.detach().cpu()
segmentation = feature_extractor.post_process_semantic_segmentation(outputs=outputs, target_sizes=[(50, 60)])
expected_shape = torch.Size((50, 60))
self.assertEqual(segmentation[0].shape, expected_shape)
segmentation = feature_extractor.post_process_semantic_segmentation(outputs=outputs)
expected_shape = torch.Size((32, 32))
self.assertEqual(segmentation[0].shape, expected_shape)