[MaskFormer] Add support for ResNet backbone (#20483)

* Add SwinBackbone

* Add hidden_states_before_downsampling support

* Fix Swin tests

* Improve conversion script

* Add id2label mappings

* Add vistas mapping

* Update comments

* Fix backbone

* Improve tests

* Extend conversion script

* Add Swin conversion script

* Fix style

* Revert config attribute

* Remove SwinBackbone from main init

* Remove unused attribute

* Use encoder for ResNet backbone

* Improve conversion script and add integration test

* Apply suggestion

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
NielsRogge
2022-12-07 09:42:38 +01:00
committed by GitHub
parent 6c1a0b3931
commit b610c47f89
8 changed files with 808 additions and 24 deletions

View File

@@ -320,16 +320,16 @@ def prepare_img():
@require_vision
@slow
class MaskFormerModelIntegrationTest(unittest.TestCase):
@cached_property
def model_checkpoints(self):
return "facebook/maskformer-swin-small-coco"
@cached_property
def default_feature_extractor(self):
return MaskFormerFeatureExtractor.from_pretrained(self.model_checkpoints) if is_vision_available() else None
return (
MaskFormerFeatureExtractor.from_pretrained("facebook/maskformer-swin-small-coco")
if is_vision_available()
else None
)
def test_inference_no_head(self):
model = MaskFormerModel.from_pretrained(self.model_checkpoints).to(torch_device)
model = MaskFormerModel.from_pretrained("facebook/maskformer-swin-small-coco").to(torch_device)
feature_extractor = self.default_feature_extractor
image = prepare_img()
inputs = feature_extractor(image, return_tensors="pt").to(torch_device)
@@ -370,7 +370,11 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
)
def test_inference_instance_segmentation_head(self):
model = MaskFormerForInstanceSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval()
model = (
MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-small-coco")
.to(torch_device)
.eval()
)
feature_extractor = self.default_feature_extractor
image = prepare_img()
inputs = feature_extractor(image, return_tensors="pt").to(torch_device)
@@ -385,7 +389,8 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
# masks_queries_logits
masks_queries_logits = outputs.masks_queries_logits
self.assertEqual(
masks_queries_logits.shape, (1, model.config.num_queries, inputs_shape[-2] // 4, inputs_shape[-1] // 4)
masks_queries_logits.shape,
(1, model.config.decoder_config.num_queries, inputs_shape[-2] // 4, inputs_shape[-1] // 4),
)
expected_slice = [
[-1.3737124, -1.7724937, -1.9364233],
@@ -396,7 +401,9 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
self.assertTrue(torch.allclose(masks_queries_logits[0, 0, :3, :3], expected_slice, atol=TOLERANCE))
# class_queries_logits
class_queries_logits = outputs.class_queries_logits
self.assertEqual(class_queries_logits.shape, (1, model.config.num_queries, model.config.num_labels + 1))
self.assertEqual(
class_queries_logits.shape, (1, model.config.decoder_config.num_queries, model.config.num_labels + 1)
)
expected_slice = torch.tensor(
[
[1.6512e00, -5.2572e00, -3.3519e00],
@@ -406,8 +413,48 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
).to(torch_device)
self.assertTrue(torch.allclose(outputs.class_queries_logits[0, :3, :3], expected_slice, atol=TOLERANCE))
def test_inference_instance_segmentation_head_resnet_backbone(self):
model = (
MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-resnet101-coco-stuff")
.to(torch_device)
.eval()
)
feature_extractor = self.default_feature_extractor
image = prepare_img()
inputs = feature_extractor(image, return_tensors="pt").to(torch_device)
inputs_shape = inputs["pixel_values"].shape
# check size is divisible by 32
self.assertTrue((inputs_shape[-1] % 32) == 0 and (inputs_shape[-2] % 32) == 0)
# check size
self.assertEqual(inputs_shape, (1, 3, 800, 1088))
with torch.no_grad():
outputs = model(**inputs)
# masks_queries_logits
masks_queries_logits = outputs.masks_queries_logits
self.assertEqual(
masks_queries_logits.shape,
(1, model.config.decoder_config.num_queries, inputs_shape[-2] // 4, inputs_shape[-1] // 4),
)
expected_slice = [[-0.9046, -2.6366, -4.6062], [-3.4179, -5.7890, -8.8057], [-4.9179, -7.6560, -10.7711]]
expected_slice = torch.tensor(expected_slice).to(torch_device)
self.assertTrue(torch.allclose(masks_queries_logits[0, 0, :3, :3], expected_slice, atol=TOLERANCE))
# class_queries_logits
class_queries_logits = outputs.class_queries_logits
self.assertEqual(
class_queries_logits.shape, (1, model.config.decoder_config.num_queries, model.config.num_labels + 1)
)
expected_slice = torch.tensor(
[[4.7188, -3.2585, -2.8857], [6.6871, -2.9181, -1.2487], [7.2449, -2.2764, -2.1874]]
).to(torch_device)
self.assertTrue(torch.allclose(outputs.class_queries_logits[0, :3, :3], expected_slice, atol=TOLERANCE))
def test_with_segmentation_maps_and_loss(self):
model = MaskFormerForInstanceSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval()
model = (
MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-small-coco")
.to(torch_device)
.eval()
)
feature_extractor = self.default_feature_extractor
inputs = feature_extractor(