[MaskFormer] Add support for ResNet backbone (#20483)
* Add SwinBackbone
* Add hidden_states_before_downsampling support
* Fix Swin tests
* Improve conversion script
* Add id2label mappings
* Add vistas mapping
* Update comments
* Fix backbone
* Improve tests
* Extend conversion script
* Add Swin conversion script
* Fix style
* Revert config attribute
* Remove SwinBackbone from main init
* Remove unused attribute
* Use encoder for ResNet backbone
* Improve conversion script and add integration test
* Apply suggestion

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
@@ -320,16 +320,16 @@ def prepare_img():
|
||||
@require_vision
|
||||
@slow
|
||||
class MaskFormerModelIntegrationTest(unittest.TestCase):
|
||||
@cached_property
|
||||
def model_checkpoints(self):
|
||||
return "facebook/maskformer-swin-small-coco"
|
||||
|
||||
@cached_property
|
||||
def default_feature_extractor(self):
|
||||
return MaskFormerFeatureExtractor.from_pretrained(self.model_checkpoints) if is_vision_available() else None
|
||||
return (
|
||||
MaskFormerFeatureExtractor.from_pretrained("facebook/maskformer-swin-small-coco")
|
||||
if is_vision_available()
|
||||
else None
|
||||
)
|
||||
|
||||
def test_inference_no_head(self):
|
||||
model = MaskFormerModel.from_pretrained(self.model_checkpoints).to(torch_device)
|
||||
model = MaskFormerModel.from_pretrained("facebook/maskformer-swin-small-coco").to(torch_device)
|
||||
feature_extractor = self.default_feature_extractor
|
||||
image = prepare_img()
|
||||
inputs = feature_extractor(image, return_tensors="pt").to(torch_device)
|
||||
@@ -370,7 +370,11 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
|
||||
)
|
||||
|
||||
def test_inference_instance_segmentation_head(self):
|
||||
model = MaskFormerForInstanceSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval()
|
||||
model = (
|
||||
MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-small-coco")
|
||||
.to(torch_device)
|
||||
.eval()
|
||||
)
|
||||
feature_extractor = self.default_feature_extractor
|
||||
image = prepare_img()
|
||||
inputs = feature_extractor(image, return_tensors="pt").to(torch_device)
|
||||
@@ -385,7 +389,8 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
|
||||
# masks_queries_logits
|
||||
masks_queries_logits = outputs.masks_queries_logits
|
||||
self.assertEqual(
|
||||
masks_queries_logits.shape, (1, model.config.num_queries, inputs_shape[-2] // 4, inputs_shape[-1] // 4)
|
||||
masks_queries_logits.shape,
|
||||
(1, model.config.decoder_config.num_queries, inputs_shape[-2] // 4, inputs_shape[-1] // 4),
|
||||
)
|
||||
expected_slice = [
|
||||
[-1.3737124, -1.7724937, -1.9364233],
|
||||
@@ -396,7 +401,9 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
|
||||
self.assertTrue(torch.allclose(masks_queries_logits[0, 0, :3, :3], expected_slice, atol=TOLERANCE))
|
||||
# class_queries_logits
|
||||
class_queries_logits = outputs.class_queries_logits
|
||||
self.assertEqual(class_queries_logits.shape, (1, model.config.num_queries, model.config.num_labels + 1))
|
||||
self.assertEqual(
|
||||
class_queries_logits.shape, (1, model.config.decoder_config.num_queries, model.config.num_labels + 1)
|
||||
)
|
||||
expected_slice = torch.tensor(
|
||||
[
|
||||
[1.6512e00, -5.2572e00, -3.3519e00],
|
||||
@@ -406,8 +413,48 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
|
||||
).to(torch_device)
|
||||
self.assertTrue(torch.allclose(outputs.class_queries_logits[0, :3, :3], expected_slice, atol=TOLERANCE))
|
||||
|
||||
def test_inference_instance_segmentation_head_resnet_backbone(self):
    """Check instance-segmentation logits of the `facebook/maskformer-resnet101-coco-stuff` checkpoint.

    Runs a single forward pass on the image returned by `prepare_img()` and verifies
    the preprocessed input size, the shapes of the mask and class query logits, and a
    3x3 slice of each against hard-coded reference values (within `TOLERANCE`).
    """
    checkpoint = "facebook/maskformer-resnet101-coco-stuff"
    model = MaskFormerForInstanceSegmentation.from_pretrained(checkpoint)
    model = model.to(torch_device).eval()

    feature_extractor = self.default_feature_extractor
    inputs = feature_extractor(prepare_img(), return_tensors="pt").to(torch_device)

    pixel_shape = inputs["pixel_values"].shape
    height, width = pixel_shape[-2], pixel_shape[-1]
    # both spatial dimensions must be multiples of 32
    self.assertTrue(width % 32 == 0 and height % 32 == 0)
    # exact preprocessed size expected for the test image
    self.assertEqual(pixel_shape, (1, 3, 800, 1088))

    with torch.no_grad():
        outputs = model(**inputs)

    # masks_queries_logits: one mask per query at 1/4 of the input resolution
    mask_logits = outputs.masks_queries_logits
    self.assertEqual(
        mask_logits.shape,
        (1, model.config.decoder_config.num_queries, height // 4, width // 4),
    )
    expected_mask_slice = torch.tensor(
        [
            [-0.9046, -2.6366, -4.6062],
            [-3.4179, -5.7890, -8.8057],
            [-4.9179, -7.6560, -10.7711],
        ]
    ).to(torch_device)
    self.assertTrue(torch.allclose(mask_logits[0, 0, :3, :3], expected_mask_slice, atol=TOLERANCE))

    # class_queries_logits: one score per label plus one extra class for every query
    class_logits = outputs.class_queries_logits
    self.assertEqual(
        class_logits.shape,
        (1, model.config.decoder_config.num_queries, model.config.num_labels + 1),
    )
    expected_class_slice = torch.tensor(
        [
            [4.7188, -3.2585, -2.8857],
            [6.6871, -2.9181, -1.2487],
            [7.2449, -2.2764, -2.1874],
        ]
    ).to(torch_device)
    self.assertTrue(torch.allclose(class_logits[0, :3, :3], expected_class_slice, atol=TOLERANCE))
|
||||
|
||||
def test_with_segmentation_maps_and_loss(self):
|
||||
model = MaskFormerForInstanceSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval()
|
||||
model = (
|
||||
MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-small-coco")
|
||||
.to(torch_device)
|
||||
.eval()
|
||||
)
|
||||
feature_extractor = self.default_feature_extractor
|
||||
|
||||
inputs = feature_extractor(
|
||||
|
||||
Reference in New Issue
Block a user