Add semantic segmentation post-processing method to MobileViT (#19105)
* add post-processing method for semantic segmentation Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -340,3 +340,27 @@ class MobileViTModelIntegrationTest(unittest.TestCase):
|
||||
)
|
||||
|
||||
self.assertTrue(torch.allclose(logits[0, :3, :3, :3], expected_slice, atol=1e-4))
|
||||
|
||||
@slow
|
||||
def test_post_processing_semantic_segmentation(self):
|
||||
model = MobileViTForSemanticSegmentation.from_pretrained("apple/deeplabv3-mobilevit-xx-small")
|
||||
model = model.to(torch_device)
|
||||
|
||||
feature_extractor = MobileViTFeatureExtractor.from_pretrained("apple/deeplabv3-mobilevit-xx-small")
|
||||
|
||||
image = prepare_img()
|
||||
inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
|
||||
|
||||
# forward pass
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
|
||||
outputs.logits = outputs.logits.detach().cpu()
|
||||
|
||||
segmentation = feature_extractor.post_process_semantic_segmentation(outputs=outputs, target_sizes=[(50, 60)])
|
||||
expected_shape = torch.Size((50, 60))
|
||||
self.assertEqual(segmentation[0].shape, expected_shape)
|
||||
|
||||
segmentation = feature_extractor.post_process_semantic_segmentation(outputs=outputs)
|
||||
expected_shape = torch.Size((32, 32))
|
||||
self.assertEqual(segmentation[0].shape, expected_shape)
|
||||
|
||||
Reference in New Issue
Block a user