fix for the output from post_process_panoptic_segmentation (#15916)
This commit is contained in:
committed by
GitHub
parent
7c45fe747f
commit
742273a52a
@@ -538,7 +538,6 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM
|
|||||||
# create the area, since bool we just need to sum :)
|
# create the area, since bool we just need to sum :)
|
||||||
mask_k_area = mask_k.sum()
|
mask_k_area = mask_k.sum()
|
||||||
# this is the area of all the stuff in query k
|
# this is the area of all the stuff in query k
|
||||||
# TODO not 100%, why are the taking the k query here????
|
|
||||||
original_area = (mask_probs[k] >= 0.5).sum()
|
original_area = (mask_probs[k] >= 0.5).sum()
|
||||||
|
|
||||||
mask_does_exist = mask_k_area > 0 and original_area > 0
|
mask_does_exist = mask_k_area > 0 and original_area > 0
|
||||||
|
|||||||
@@ -404,3 +404,23 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
|
|||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
|
|
||||||
self.assertTrue(outputs.loss is not None)
|
self.assertTrue(outputs.loss is not None)
|
||||||
|
|
||||||
|
def test_panoptic_segmentation(self):
|
||||||
|
model = MaskFormerForInstanceSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval()
|
||||||
|
feature_extractor = self.default_feature_extractor
|
||||||
|
|
||||||
|
inputs = feature_extractor(
|
||||||
|
[np.zeros((3, 384, 384)), np.zeros((3, 384, 384))],
|
||||||
|
annotations=[
|
||||||
|
{"masks": np.random.rand(10, 384, 384).astype(np.float32), "labels": np.zeros(10).astype(np.int64)},
|
||||||
|
{"masks": np.random.rand(10, 384, 384).astype(np.float32), "labels": np.zeros(10).astype(np.int64)},
|
||||||
|
],
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(**inputs)
|
||||||
|
|
||||||
|
panoptic_segmentation = feature_extractor.post_process_panoptic_segmentation(outputs)
|
||||||
|
|
||||||
|
self.assertTrue(len(panoptic_segmentation) == 2)
|
||||||
|
|||||||
Reference in New Issue
Block a user