Fix BeitFeatureExtractor postprocessing (#19119)

* return post-processed segmentations as list, add test
* use torch to resize logits
* fix assertion error if no target_size is specified
This commit is contained in:
Alara Dirik
2022-09-20 18:53:40 +03:00
committed by GitHub
parent 36e356caa4
commit 36b9a99433
2 changed files with 47 additions and 22 deletions

View File

@@ -455,3 +455,28 @@ class BeitModelIntegrationTest(unittest.TestCase):
)
self.assertTrue(torch.allclose(logits[0, :3, :3, :3], expected_slice, atol=1e-4))
@slow
def test_post_processing_semantic_segmentation(self):
model = BeitForSemanticSegmentation.from_pretrained("microsoft/beit-base-finetuned-ade-640-640")
model = model.to(torch_device)
feature_extractor = BeitFeatureExtractor(do_resize=True, size=640, do_center_crop=False)
ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
image = Image.open(ds[0]["file"])
inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
# forward pass
with torch.no_grad():
outputs = model(**inputs)
outputs.logits = outputs.logits.detach().cpu()
segmentation = feature_extractor.post_process_semantic_segmentation(outputs=outputs, target_sizes=[(500, 300)])
expected_shape = torch.Size((500, 300))
self.assertEqual(segmentation[0].shape, expected_shape)
segmentation = feature_extractor.post_process_semantic_segmentation(outputs=outputs)
expected_shape = torch.Size((160, 160))
self.assertEqual(segmentation[0].shape, expected_shape)