Feature Extractor accepts segmentation_maps (#15964)
* feature extractor accepts * resolved conversations * added examples in test for ADE20K * num_classes -> num_labels * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * resolving conversations * resolving conversations * removed ADE * CI * minor changes in conversion script * reduce_labels in feature extractor * minor changes * correct preprocess for instance segmentation maps * minor changes * minor changes * CI * debugging * better padding * going to update labels inside the model * going to update labels inside the model * minor changes * tests * removed changes in feature_extractor_utils * conversation * conversation * example in feature extractor * more docstring in modeling * test * make style * doc Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
c2f8eaf6bc
commit
c4deb7b3ae
@@ -49,6 +49,9 @@ class MaskFormerFeatureExtractionTester(unittest.TestCase):
|
||||
do_normalize=True,
|
||||
image_mean=[0.5, 0.5, 0.5],
|
||||
image_std=[0.5, 0.5, 0.5],
|
||||
num_labels=10,
|
||||
reduce_labels=True,
|
||||
ignore_index=255,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
@@ -68,6 +71,9 @@ class MaskFormerFeatureExtractionTester(unittest.TestCase):
|
||||
self.num_classes = 2
|
||||
self.height = 3
|
||||
self.width = 4
|
||||
self.num_labels = num_labels
|
||||
self.reduce_labels = reduce_labels
|
||||
self.ignore_index = ignore_index
|
||||
|
||||
def prepare_feat_extract_dict(self):
|
||||
return {
|
||||
@@ -78,6 +84,9 @@ class MaskFormerFeatureExtractionTester(unittest.TestCase):
|
||||
"image_mean": self.image_mean,
|
||||
"image_std": self.image_std,
|
||||
"size_divisibility": self.size_divisibility,
|
||||
"num_labels": self.num_labels,
|
||||
"reduce_labels": self.reduce_labels,
|
||||
"ignore_index": self.ignore_index,
|
||||
}
|
||||
|
||||
def get_expected_values(self, image_inputs, batched=False):
|
||||
@@ -140,6 +149,8 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
|
||||
self.assertTrue(hasattr(feature_extractor, "do_resize"))
|
||||
self.assertTrue(hasattr(feature_extractor, "size"))
|
||||
self.assertTrue(hasattr(feature_extractor, "max_size"))
|
||||
self.assertTrue(hasattr(feature_extractor, "ignore_index"))
|
||||
self.assertTrue(hasattr(feature_extractor, "num_labels"))
|
||||
|
||||
def test_batch_feature(self):
|
||||
pass
|
||||
@@ -245,7 +256,9 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
|
||||
def test_equivalence_pad_and_create_pixel_mask(self):
|
||||
# Initialize feature_extractors
|
||||
feature_extractor_1 = self.feature_extraction_class(**self.feat_extract_dict)
|
||||
feature_extractor_2 = self.feature_extraction_class(do_resize=False, do_normalize=False)
|
||||
feature_extractor_2 = self.feature_extraction_class(
|
||||
do_resize=False, do_normalize=False, num_labels=self.feature_extract_tester.num_classes
|
||||
)
|
||||
# create random PyTorch tensors
|
||||
image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True)
|
||||
for image in image_inputs:
|
||||
@@ -262,28 +275,41 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
|
||||
torch.allclose(encoded_images_with_method["pixel_mask"], encoded_images["pixel_mask"], atol=1e-4)
|
||||
)
|
||||
|
||||
def comm_get_feature_extractor_inputs(self, with_annotations=False):
|
||||
def comm_get_feature_extractor_inputs(
|
||||
self, with_segmentation_maps=False, is_instance_map=False, segmentation_type="np"
|
||||
):
|
||||
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
|
||||
# prepare image and target
|
||||
num_classes = 8
|
||||
batch_size = self.feature_extract_tester.batch_size
|
||||
num_labels = self.feature_extract_tester.num_labels
|
||||
annotations = None
|
||||
|
||||
if with_annotations:
|
||||
annotations = [
|
||||
{
|
||||
"masks": np.random.rand(num_classes, 384, 384).astype(np.float32),
|
||||
"labels": (np.random.rand(num_classes) > 0.5).astype(np.int64),
|
||||
instance_id_to_semantic_id = None
|
||||
if with_segmentation_maps:
|
||||
high = num_labels
|
||||
if is_instance_map:
|
||||
high * 2
|
||||
labels_expanded = list(range(num_labels)) * 2
|
||||
instance_id_to_semantic_id = {
|
||||
instance_id: label_id for instance_id, label_id in enumerate(labels_expanded)
|
||||
}
|
||||
for _ in range(batch_size)
|
||||
]
|
||||
annotations = [np.random.randint(0, high, (384, 384)).astype(np.uint8) for _ in range(batch_size)]
|
||||
if segmentation_type == "pil":
|
||||
annotations = [Image.fromarray(annotation) for annotation in annotations]
|
||||
|
||||
image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False)
|
||||
|
||||
inputs = feature_extractor(image_inputs, annotations, return_tensors="pt", pad_and_return_pixel_mask=True)
|
||||
inputs = feature_extractor(
|
||||
image_inputs,
|
||||
annotations,
|
||||
return_tensors="pt",
|
||||
instance_id_to_semantic_id=instance_id_to_semantic_id,
|
||||
pad_and_return_pixel_mask=True,
|
||||
)
|
||||
|
||||
return inputs
|
||||
|
||||
def test_init_without_params(self):
|
||||
pass
|
||||
|
||||
def test_with_size_divisibility(self):
|
||||
size_divisibilities = [8, 16, 32]
|
||||
weird_input_sizes = [(407, 802), (582, 1094)]
|
||||
@@ -297,27 +323,29 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
|
||||
self.assertTrue((pixel_values.shape[-1] % size_divisibility) == 0)
|
||||
self.assertTrue((pixel_values.shape[-2] % size_divisibility) == 0)
|
||||
|
||||
def test_call_with_numpy_annotations(self):
|
||||
num_classes = 8
|
||||
batch_size = self.feature_extract_tester.batch_size
|
||||
def test_call_with_segmentation_maps(self):
|
||||
def common(is_instance_map=False, segmentation_type=None):
|
||||
inputs = self.comm_get_feature_extractor_inputs(
|
||||
with_segmentation_maps=True, is_instance_map=is_instance_map, segmentation_type=segmentation_type
|
||||
)
|
||||
|
||||
inputs = self.comm_get_feature_extractor_inputs(with_annotations=True)
|
||||
mask_labels = inputs["mask_labels"]
|
||||
class_labels = inputs["class_labels"]
|
||||
pixel_values = inputs["pixel_values"]
|
||||
|
||||
# check the batch_size
|
||||
for el in inputs.values():
|
||||
self.assertEqual(el.shape[0], batch_size)
|
||||
# check the batch_size
|
||||
for mask_label, class_label in zip(mask_labels, class_labels):
|
||||
self.assertEqual(mask_label.shape[0], class_label.shape[0])
|
||||
# this ensure padding has happened
|
||||
self.assertEqual(mask_label.shape[1:], pixel_values.shape[2:])
|
||||
|
||||
pixel_values = inputs["pixel_values"]
|
||||
mask_labels = inputs["mask_labels"]
|
||||
class_labels = inputs["class_labels"]
|
||||
|
||||
self.assertEqual(pixel_values.shape[-2], mask_labels.shape[-2])
|
||||
self.assertEqual(pixel_values.shape[-1], mask_labels.shape[-1])
|
||||
self.assertEqual(mask_labels.shape[1], class_labels.shape[1])
|
||||
self.assertEqual(mask_labels.shape[1], num_classes)
|
||||
common()
|
||||
common(is_instance_map=True)
|
||||
common(is_instance_map=False, segmentation_type="pil")
|
||||
common(is_instance_map=True, segmentation_type="pil")
|
||||
|
||||
def test_post_process_segmentation(self):
|
||||
fature_extractor = self.feature_extraction_class()
|
||||
fature_extractor = self.feature_extraction_class(num_labels=self.feature_extract_tester.num_classes)
|
||||
outputs = self.feature_extract_tester.get_fake_maskformer_outputs()
|
||||
segmentation = fature_extractor.post_process_segmentation(outputs)
|
||||
|
||||
@@ -340,7 +368,7 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
|
||||
)
|
||||
|
||||
def test_post_process_semantic_segmentation(self):
|
||||
fature_extractor = self.feature_extraction_class()
|
||||
fature_extractor = self.feature_extraction_class(num_labels=self.feature_extract_tester.num_classes)
|
||||
outputs = self.feature_extract_tester.get_fake_maskformer_outputs()
|
||||
|
||||
segmentation = fature_extractor.post_process_semantic_segmentation(outputs)
|
||||
@@ -361,7 +389,7 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
|
||||
self.assertEqual(segmentation.shape, (self.feature_extract_tester.batch_size, *target_size))
|
||||
|
||||
def test_post_process_panoptic_segmentation(self):
|
||||
fature_extractor = self.feature_extraction_class()
|
||||
fature_extractor = self.feature_extraction_class(num_labels=self.feature_extract_tester.num_classes)
|
||||
outputs = self.feature_extract_tester.get_fake_maskformer_outputs()
|
||||
segmentation = fature_extractor.post_process_panoptic_segmentation(outputs, object_mask_threshold=0)
|
||||
|
||||
|
||||
@@ -397,18 +397,19 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
|
||||
).to(torch_device)
|
||||
self.assertTrue(torch.allclose(outputs.class_queries_logits[0, :3, :3], expected_slice, atol=TOLERANCE))
|
||||
|
||||
def test_with_annotations_and_loss(self):
|
||||
def test_with_segmentation_maps_and_loss(self):
|
||||
model = MaskFormerForInstanceSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval()
|
||||
feature_extractor = self.default_feature_extractor
|
||||
|
||||
inputs = feature_extractor(
|
||||
[np.zeros((3, 800, 1333)), np.zeros((3, 800, 1333))],
|
||||
annotations=[
|
||||
{"masks": np.random.rand(10, 384, 384).astype(np.float32), "labels": np.zeros(10).astype(np.int64)},
|
||||
{"masks": np.random.rand(10, 384, 384).astype(np.float32), "labels": np.zeros(10).astype(np.int64)},
|
||||
],
|
||||
segmentation_maps=[np.zeros((384, 384)).astype(np.float32), np.zeros((384, 384)).astype(np.float32)],
|
||||
return_tensors="pt",
|
||||
).to(torch_device)
|
||||
)
|
||||
|
||||
inputs["pixel_values"] = inputs["pixel_values"].to(torch_device)
|
||||
inputs["mask_labels"] = [el.to(torch_device) for el in inputs["mask_labels"]]
|
||||
inputs["class_labels"] = [el.to(torch_device) for el in inputs["class_labels"]]
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
|
||||
Reference in New Issue
Block a user