[MaskFormer] Add support for ResNet backbone (#20483)
* Add SwinBackbone * Add hidden_states_before_downsampling support * Fix Swin tests * Improve conversion script * Add id2label mappings * Add vistas mapping * Update comments * Fix backbone * Improve tests * Extend conversion script * Add Swin conversion script * Fix style * Revert config attribute * Remove SwinBackbone from main init * Remove unused attribute * Use encoder for ResNet backbone * Improve conversion script and add integration test * Apply suggestion Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
@@ -18,7 +18,7 @@ from typing import Dict, Optional
|
|||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from ..auto.configuration_auto import AutoConfig
|
from ..auto import CONFIG_MAPPING
|
||||||
from ..detr import DetrConfig
|
from ..detr import DetrConfig
|
||||||
from ..swin import SwinConfig
|
from ..swin import SwinConfig
|
||||||
|
|
||||||
@@ -97,7 +97,7 @@ class MaskFormerConfig(PretrainedConfig):
|
|||||||
"""
|
"""
|
||||||
model_type = "maskformer"
|
model_type = "maskformer"
|
||||||
attribute_map = {"hidden_size": "mask_feature_size"}
|
attribute_map = {"hidden_size": "mask_feature_size"}
|
||||||
backbones_supported = ["swin"]
|
backbones_supported = ["resnet", "swin"]
|
||||||
decoders_supported = ["detr"]
|
decoders_supported = ["detr"]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -127,27 +127,38 @@ class MaskFormerConfig(PretrainedConfig):
|
|||||||
num_heads=[4, 8, 16, 32],
|
num_heads=[4, 8, 16, 32],
|
||||||
window_size=12,
|
window_size=12,
|
||||||
drop_path_rate=0.3,
|
drop_path_rate=0.3,
|
||||||
|
out_features=["stage1", "stage2", "stage3", "stage4"],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
backbone_model_type = backbone_config.pop("model_type")
|
# verify that the backbone is supported
|
||||||
|
backbone_model_type = (
|
||||||
|
backbone_config.pop("model_type") if isinstance(backbone_config, dict) else backbone_config.model_type
|
||||||
|
)
|
||||||
if backbone_model_type not in self.backbones_supported:
|
if backbone_model_type not in self.backbones_supported:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Backbone {backbone_model_type} not supported, please use one of"
|
f"Backbone {backbone_model_type} not supported, please use one of"
|
||||||
f" {','.join(self.backbones_supported)}"
|
f" {','.join(self.backbones_supported)}"
|
||||||
)
|
)
|
||||||
backbone_config = AutoConfig.for_model(backbone_model_type, **backbone_config)
|
if isinstance(backbone_config, dict):
|
||||||
|
config_class = CONFIG_MAPPING[backbone_model_type]
|
||||||
|
backbone_config = config_class.from_dict(backbone_config)
|
||||||
|
|
||||||
if decoder_config is None:
|
if decoder_config is None:
|
||||||
# fall back to https://huggingface.co/facebook/detr-resnet-50
|
# fall back to https://huggingface.co/facebook/detr-resnet-50
|
||||||
decoder_config = DetrConfig()
|
decoder_config = DetrConfig()
|
||||||
else:
|
else:
|
||||||
decoder_type = decoder_config.pop("model_type")
|
# verify that the decoder is supported
|
||||||
|
decoder_type = (
|
||||||
|
decoder_config.pop("model_type") if isinstance(decoder_config, dict) else decoder_config.model_type
|
||||||
|
)
|
||||||
if decoder_type not in self.decoders_supported:
|
if decoder_type not in self.decoders_supported:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Transformer Decoder {decoder_type} not supported, please use one of"
|
f"Transformer Decoder {decoder_type} not supported, please use one of"
|
||||||
f" {','.join(self.decoders_supported)}"
|
f" {','.join(self.decoders_supported)}"
|
||||||
)
|
)
|
||||||
decoder_config = AutoConfig.for_model(decoder_type, **decoder_config)
|
if isinstance(decoder_config, dict):
|
||||||
|
config_class = CONFIG_MAPPING[decoder_type]
|
||||||
|
decoder_config = config_class.from_dict(decoder_config)
|
||||||
|
|
||||||
self.backbone_config = backbone_config
|
self.backbone_config = backbone_config
|
||||||
self.decoder_config = decoder_config
|
self.decoder_config = decoder_config
|
||||||
@@ -186,8 +197,8 @@ class MaskFormerConfig(PretrainedConfig):
|
|||||||
[`MaskFormerConfig`]: An instance of a configuration object
|
[`MaskFormerConfig`]: An instance of a configuration object
|
||||||
"""
|
"""
|
||||||
return cls(
|
return cls(
|
||||||
backbone_config=backbone_config.to_dict(),
|
backbone_config=backbone_config,
|
||||||
decoder_config=decoder_config.to_dict(),
|
decoder_config=decoder_config,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -69,7 +69,7 @@ class MaskFormerSwinConfig(PretrainedConfig):
|
|||||||
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
|
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
|
||||||
The epsilon used by the layer normalization layers.
|
The epsilon used by the layer normalization layers.
|
||||||
out_features (`List[str]`, *optional*):
|
out_features (`List[str]`, *optional*):
|
||||||
If used as a backbone, list of feature names to output, e.g. `["stem", "stage1"]`.
|
If used as a backbone, list of feature names to output, e.g. `["stage1", "stage2"]`.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,390 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2022 The HuggingFace Inc. team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Convert MaskFormer checkpoints with ResNet backbone from the original repository. URL:
|
||||||
|
https://github.com/facebookresearch/MaskFormer"""
|
||||||
|
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import pickle
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
from transformers import MaskFormerConfig, MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation, ResNetConfig
|
||||||
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
|
||||||
|
logging.set_verbosity_info()
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_maskformer_config(model_name: str):
    """
    Build the `MaskFormerConfig` (ResNet backbone) matching an original checkpoint name.

    The backbone checkpoint and the label set are both selected by substring matching on `model_name`.

    Args:
        model_name (`str`):
            Name of the checkpoint to convert, e.g. `"maskformer-resnet50-ade"`.

    Returns:
        `MaskFormerConfig`: configuration with `num_labels`, `id2label` and `label2id` filled in.

    Raises:
        NotImplementedError: if the name refers to a ResNet-C ("resnet101c") backbone.
        ValueError: if the name does not contain any known dataset marker.
    """
    if "resnet101c" in model_name:
        # TODO add support for ResNet-C backbone, which uses a "deeplab" stem
        raise NotImplementedError("To do")

    backbone_checkpoint = "microsoft/resnet-101" if "resnet101" in model_name else "microsoft/resnet-50"
    backbone_config = ResNetConfig.from_pretrained(
        backbone_checkpoint, out_features=["stage1", "stage2", "stage3", "stage4"]
    )
    config = MaskFormerConfig(backbone_config=backbone_config)

    # (marker, number of labels, label file). Order matters: the more specific markers
    # ("ade20k-full", "coco-stuff") must be tested before their prefixes ("ade", "coco").
    datasets = (
        ("ade20k-full", 847, "maskformer-ade20k-full-id2label.json"),
        ("ade", 150, "ade20k-id2label.json"),
        ("coco-stuff", 171, "maskformer-coco-stuff-id2label.json"),
        # TODO (plain "coco" entry was marked as unfinished in the original script)
        ("coco", 133, "coco-panoptic-id2label.json"),
        ("cityscapes", 19, "cityscapes-id2label.json"),
        ("vistas", 65, "mapillary-vistas-id2label.json"),
    )
    for marker, num_labels, filename in datasets:
        if marker in model_name:
            break
    else:
        # Previously this fell through and failed later on an unbound `filename`.
        raise ValueError(f"Cannot infer the dataset from model name {model_name!r}")

    config.num_labels = num_labels

    repo_id = "huggingface/label-files"
    # Use a context manager so the downloaded label file is closed after parsing.
    with open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r", encoding="utf-8") as label_file:
        id2label = json.load(label_file)
    id2label = {int(k): v for k, v in id2label.items()}
    config.id2label = id2label
    config.label2id = {v: k for k, v in id2label.items()}

    return config
|
||||||
|
|
||||||
|
|
||||||
|
def create_rename_keys(config):
    """
    List the `(original_name, converted_name)` pairs for every parameter that is moved by a plain rename.

    The number of backbone entries follows `config.backbone_config.depths`, the number of Transformer decoder
    entries follows `config.decoder_config.decoder_layers`.

    Args:
        config: object exposing `backbone_config.depths` and `decoder_config.decoder_layers`.

    Returns:
        `list` of 2-tuples of `str`, in the order: stem, backbone stages, FPN, Transformer decoder, heads.
    """
    pairs = []

    def add(source, target):
        pairs.append((source, target))

    # (original suffix, converted suffix) for one convolution followed by its normalization layer
    conv_norm_suffixes = (
        ("weight", "convolution.weight"),
        ("norm.weight", "normalization.weight"),
        ("norm.bias", "normalization.bias"),
        ("norm.running_mean", "normalization.running_mean"),
        ("norm.running_var", "normalization.running_var"),
    )

    # stem
    for old_suffix, new_suffix in conv_norm_suffixes:
        add(f"backbone.stem.conv1.{old_suffix}", f"model.pixel_level_module.encoder.embedder.embedder.{new_suffix}")

    # stages
    for stage_idx, depth in enumerate(config.backbone_config.depths):
        for layer_idx in range(depth):
            # original stages are named res2, res3, ... hence the offset of 2
            old_layer = f"backbone.res{stage_idx + 2}.{layer_idx}"
            new_layer = f"model.pixel_level_module.encoder.encoder.stages.{stage_idx}.layers.{layer_idx}"
            # only the first layer of a stage has shortcut entries
            if layer_idx == 0:
                for old_suffix, new_suffix in conv_norm_suffixes:
                    add(f"{old_layer}.shortcut.{old_suffix}", f"{new_layer}.shortcut.{new_suffix}")
            # 3 convs
            for conv_idx in range(3):
                for old_suffix, new_suffix in conv_norm_suffixes:
                    add(f"{old_layer}.conv{conv_idx + 1}.{old_suffix}", f"{new_layer}.layer.{conv_idx}.{new_suffix}")

    # FPN
    fpn_suffixes = (("weight", "0.weight"), ("norm.weight", "1.weight"), ("norm.bias", "1.bias"))
    fpn_root = "model.pixel_level_module.decoder.fpn"
    for old_suffix, new_suffix in fpn_suffixes:
        add(f"sem_seg_head.layer_4.{old_suffix}", f"{fpn_root}.stem.{new_suffix}")
    # original indices 3, 2, 1 map to converted indices 0, 1, 2
    for target_index, source_index in enumerate((3, 2, 1)):
        for old_suffix, new_suffix in fpn_suffixes:
            add(f"sem_seg_head.adapter_{source_index}.{old_suffix}", f"{fpn_root}.layers.{target_index}.proj.{new_suffix}")
        for old_suffix, new_suffix in fpn_suffixes:
            add(f"sem_seg_head.layer_{source_index}.{old_suffix}", f"{fpn_root}.layers.{target_index}.block.{new_suffix}")
    for param in ("weight", "bias"):
        add(f"sem_seg_head.mask_features.{param}", f"model.pixel_level_module.decoder.mask_projection.{param}")

    # Transformer decoder
    old_decoder = "sem_seg_head.predictor.transformer.decoder"
    new_decoder = "model.transformer_module.decoder"
    decoder_modules = (
        ("self_attn.out_proj", "self_attn.out_proj"),  # self-attention out projection
        ("multihead_attn.out_proj", "encoder_attn.out_proj"),  # cross-attention out projection
        ("linear1", "fc1"),  # MLP 1
        ("linear2", "fc2"),  # MLP 2
        ("norm1", "self_attn_layer_norm"),  # layernorm 1 (self-attention layernorm)
        ("norm2", "encoder_attn_layer_norm"),  # layernorm 2 (cross-attention layernorm)
        ("norm3", "final_layer_norm"),  # layernorm 3 (final layernorm)
    )
    for idx in range(config.decoder_config.decoder_layers):
        for old_module, new_module in decoder_modules:
            for param in ("weight", "bias"):
                add(f"{old_decoder}.layers.{idx}.{old_module}.{param}", f"{new_decoder}.layers.{idx}.{new_module}.{param}")
    for param in ("weight", "bias"):
        add(f"{old_decoder}.norm.{param}", f"{new_decoder}.layernorm.{param}")

    # heads on top
    old_head = "sem_seg_head.predictor"
    add(f"{old_head}.query_embed.weight", "model.transformer_module.queries_embedder.weight")
    for param in ("weight", "bias"):
        add(f"{old_head}.input_proj.{param}", f"model.transformer_module.input_projection.{param}")
    for param in ("weight", "bias"):
        add(f"{old_head}.class_embed.{param}", f"class_predictor.{param}")
    for i in range(3):
        for param in ("weight", "bias"):
            add(f"{old_head}.mask_embed.layers.{i}.{param}", f"mask_embedder.{i}.0.{param}")

    return pairs
|
||||||
|
|
||||||
|
|
||||||
|
def rename_key(dct, old, new):
    """Move the value stored under `old` to the key `new`, modifying `dct` in place."""
    dct[new] = dct.pop(old)
|
||||||
|
|
||||||
|
|
||||||
|
# we split up the fused input projection of each decoder attention layer into queries, keys and values
def read_in_decoder_q_k_v(state_dict, config):
    """
    Replace the fused attention input projections of every decoder layer by separate q/k/v entries.

    For both the self-attention (`self_attn`) and the cross-attention (`multihead_attn` in the original,
    `encoder_attn` in the converted model) of each layer, the `in_proj_weight` / `in_proj_bias` entries are removed
    from `state_dict` and sliced along the first axis into query, key and value projections (in that order).

    Args:
        state_dict (`dict`):
            Original state dict, modified in place.
        config:
            Object exposing `decoder_config.hidden_size` and `decoder_config.decoder_layers`.
    """
    # Every slice uses the decoder's hidden size. The query biases used to be sliced with `config.hidden_size`,
    # which is a different attribute of the top-level config and only matched by coincidence.
    hidden_size = config.decoder_config.hidden_size
    attention_names = (("self_attn", "self_attn"), ("multihead_attn", "encoder_attn"))

    for idx in range(config.decoder_config.decoder_layers):
        old_layer = f"sem_seg_head.predictor.transformer.decoder.layers.{idx}"
        new_layer = f"model.transformer_module.decoder.layers.{idx}"
        for old_attn, new_attn in attention_names:
            # in the original implementation, the input projection is a single matrix + bias
            in_proj_weight = state_dict.pop(f"{old_layer}.{old_attn}.in_proj_weight")
            in_proj_bias = state_dict.pop(f"{old_layer}.{old_attn}.in_proj_bias")
            # next, add query, keys and values (in that order) to the state dict
            target = f"{new_layer}.{new_attn}"
            state_dict[f"{target}.q_proj.weight"] = in_proj_weight[:hidden_size, :]
            state_dict[f"{target}.q_proj.bias"] = in_proj_bias[:hidden_size]
            state_dict[f"{target}.k_proj.weight"] = in_proj_weight[hidden_size : hidden_size * 2, :]
            state_dict[f"{target}.k_proj.bias"] = in_proj_bias[hidden_size : hidden_size * 2]
            state_dict[f"{target}.v_proj.weight"] = in_proj_weight[-hidden_size:, :]
            state_dict[f"{target}.v_proj.bias"] = in_proj_bias[-hidden_size:]
|
||||||
|
|
||||||
|
|
||||||
|
# We will verify our results on an image of cute cats
|
||||||
|
def prepare_img() -> Image.Image:
    """
    Download the COCO validation image used to verify the converted model.

    Returns:
        `PIL.Image.Image`: the image opened from the HTTP response stream (not a tensor, as the previous
        annotation claimed).
    """
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    response = requests.get(url, stream=True)
    return Image.open(response.raw)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
def convert_maskformer_checkpoint(
    model_name: str, checkpoint_path: str, pytorch_dump_folder_path: str, push_to_hub: bool = False
):
    """
    Copy/paste/tweak model's weights to our MaskFormer structure.

    Args:
        model_name (`str`):
            Name of the checkpoint being converted; selects the config, the feature extractor settings and the
            reference logits used for verification.
        checkpoint_path (`str`):
            Path to the original pickle file (.pkl) of the original checkpoint.
        pytorch_dump_folder_path (`str`):
            Output directory for the converted model and feature extractor; nothing is saved when `None`.
        push_to_hub (`bool`, *optional*, defaults to `False`):
            Whether to push the converted model and feature extractor to the hub under `facebook/{model_name}`.

    Raises:
        ValueError: if no reference logits are known for `model_name`.
        AssertionError: if the converted model's logits do not match the reference values.
    """
    # Reference values for `class_queries_logits[0, :3, :3]` on the test image, per checkpoint.
    expected_logits_by_model = {
        "maskformer-resnet50-ade": [
            [6.7710, -0.1452, -3.5687],
            [1.9165, -1.0010, -1.8614],
            [3.6209, -0.2950, -1.3813],
        ],
        "maskformer-resnet101-ade": [
            [4.0381, -1.1483, -1.9688],
            [2.7083, -1.9147, -2.2555],
            [3.4367, -1.3711, -2.1609],
        ],
        "maskformer-resnet50-coco-stuff": [
            [3.2309, -3.0481, -2.8695],
            [5.4986, -5.4242, -2.4211],
            [6.2100, -5.2279, -2.7786],
        ],
        "maskformer-resnet101-coco-stuff": [
            [4.7188, -3.2585, -2.8857],
            [6.6871, -2.9181, -1.2487],
            [7.2449, -2.2764, -2.1874],
        ],
        "maskformer-resnet101-cityscapes": [
            [-1.8861, -1.5465, 0.6749],
            [-2.3677, -1.6707, -0.0867],
            [-2.2314, -1.9530, -0.9132],
        ],
        "maskformer-resnet50-vistas": [
            [-6.3917, -1.5216, -1.1392],
            [-5.5335, -4.5318, -1.8339],
            [-4.3576, -4.0301, 0.2162],
        ],
        "maskformer-resnet50-ade20k-full": [
            [3.6146, -1.9367, -3.2534],
            [4.0099, 0.2027, -2.7576],
            [3.3913, -2.3644, -3.9519],
        ],
        "maskformer-resnet101-ade20k-full": [
            [3.2211, -1.6550, -2.7605],
            [2.8559, -2.4512, -2.9574],
            [2.6331, -2.6775, -2.1844],
        ],
    }

    config = get_maskformer_config(model_name)

    # Fail before the expensive steps instead of hitting an unbound `expected_logits` at the very end.
    if model_name not in expected_logits_by_model:
        raise ValueError(
            f"No reference logits available for {model_name}, please use one of"
            f" {','.join(expected_logits_by_model)}"
        )
    expected_logits = torch.tensor(expected_logits_by_model[model_name])

    # load original state_dict
    # NOTE: unpickling can execute arbitrary code; only convert checkpoint files from a trusted source.
    with open(checkpoint_path, "rb") as f:
        data = pickle.load(f)
    state_dict = data["model"]

    # rename keys
    for src, dest in create_rename_keys(config):
        rename_key(state_dict, src, dest)
    read_in_decoder_q_k_v(state_dict, config)

    # update to torch tensors
    state_dict = {key: torch.from_numpy(value) for key, value in state_dict.items()}

    # load 🤗 model
    model = MaskFormerForInstanceSegmentation(config)
    model.eval()

    model.load_state_dict(state_dict)

    # verify results
    image = prepare_img()
    if "vistas" in model_name:
        ignore_index = 65
    elif "cityscapes" in model_name:
        ignore_index = 65535
    else:
        ignore_index = 255
    reduce_labels = "ade" in model_name
    feature_extractor = MaskFormerFeatureExtractor(ignore_index=ignore_index, reduce_labels=reduce_labels)

    inputs = feature_extractor(image, return_tensors="pt")

    outputs = model(**inputs)

    assert torch.allclose(
        outputs.class_queries_logits[0, :3, :3], expected_logits, atol=1e-4
    ), f"Logits of converted {model_name} do not match the reference values"
    print("Looks ok!")

    if pytorch_dump_folder_path is not None:
        print(f"Saving model and feature extractor of {model_name} to {pytorch_dump_folder_path}")
        # parents=True so that a nested output path does not fail on a missing parent directory
        Path(pytorch_dump_folder_path).mkdir(parents=True, exist_ok=True)
        model.save_pretrained(pytorch_dump_folder_path)
        feature_extractor.save_pretrained(pytorch_dump_folder_path)

    if push_to_hub:
        print(f"Pushing model and feature extractor of {model_name} to the hub...")
        model.push_to_hub(f"facebook/{model_name}")
        feature_extractor.push_to_hub(f"facebook/{model_name}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Command-line entry point: parse the arguments and run the conversion.
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--model_name",
        type=str,
        # No `default`: the argument is required, so a default could never take effect.
        required=True,
        choices=[
            "maskformer-resnet50-ade",
            "maskformer-resnet101-ade",
            "maskformer-resnet50-coco-stuff",
            "maskformer-resnet101-coco-stuff",
            "maskformer-resnet101-cityscapes",
            "maskformer-resnet50-vistas",
            "maskformer-resnet50-ade20k-full",
            "maskformer-resnet101-ade20k-full",
        ],
        # `help` must be a plain string; a one-element tuple breaks `--help` formatting.
        help="Name of the MaskFormer model you'd like to convert",
    )
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        required=True,
        help="Path to the original pickle file (.pkl) of the original checkpoint.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
    )
    parser.add_argument(
        "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
    )

    args = parser.parse_args()
    convert_maskformer_checkpoint(
        args.model_name, args.checkpoint_path, args.pytorch_dump_folder_path, args.push_to_hub
    )
|
||||||
@@ -0,0 +1,333 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2022 The HuggingFace Inc. team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Convert MaskFormer checkpoints with Swin backbone from the original repository. URL:
|
||||||
|
https://github.com/facebookresearch/MaskFormer"""
|
||||||
|
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import pickle
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
from transformers import MaskFormerConfig, MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation, SwinConfig
|
||||||
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
|
||||||
|
logging.set_verbosity_info()
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_maskformer_config(model_name: str):
    """
    Build the `MaskFormerConfig` (Swin backbone) matching an original checkpoint name.

    The backbone config is always loaded from `microsoft/swin-tiny-patch4-window7-224`; the label set is selected by
    substring matching on `model_name`.

    Args:
        model_name (`str`):
            Name of the checkpoint to convert.

    Returns:
        `MaskFormerConfig`: configuration with `num_labels`, `id2label` and `label2id` filled in.

    Raises:
        ValueError: if the name does not contain any known dataset marker.
    """
    backbone_config = SwinConfig.from_pretrained(
        "microsoft/swin-tiny-patch4-window7-224", out_features=["stage1", "stage2", "stage3", "stage4"]
    )
    config = MaskFormerConfig(backbone_config=backbone_config)

    # (marker, number of labels, label file). Order matters: the more specific markers
    # ("ade20k-full", "coco-stuff") must be tested before their prefixes ("ade", "coco").
    datasets = (
        ("ade20k-full", 847, "maskformer-ade20k-full-id2label.json"),
        ("ade", 150, "ade20k-id2label.json"),
        ("coco-stuff", 171, "maskformer-coco-stuff-id2label.json"),
        # TODO (plain "coco" entry was marked as unfinished in the original script)
        ("coco", 133, "coco-panoptic-id2label.json"),
        ("cityscapes", 19, "cityscapes-id2label.json"),
        ("vistas", 65, "mapillary-vistas-id2label.json"),
    )
    for marker, num_labels, filename in datasets:
        if marker in model_name:
            break
    else:
        # Previously this fell through and failed later on an unbound `filename`.
        raise ValueError(f"Cannot infer the dataset from model name {model_name!r}")

    config.num_labels = num_labels

    repo_id = "huggingface/label-files"
    # Use a context manager so the downloaded label file is closed after parsing.
    with open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r", encoding="utf-8") as label_file:
        id2label = json.load(label_file)
    id2label = {int(k): v for k, v in id2label.items()}
    # The parsed mapping used to be discarded; store it on the config as the ResNet conversion script does.
    config.id2label = id2label
    config.label2id = {v: k for k, v in id2label.items()}

    return config
|
||||||
|
|
||||||
|
|
||||||
|
def create_rename_keys(config):
    """
    List the `(original_name, converted_name)` pairs for every parameter that only needs to be renamed.

    The fused attention projections (`attn.qkv` of the backbone, `in_proj_weight`/`in_proj_bias` of the transformer
    decoder) are not part of this list, they are split by `read_in_swin_q_k_v` and `read_in_decoder_q_k_v`.

    Args:
        config:
            Object exposing `backbone_config.depths` (number of blocks per Swin stage) and
            `decoder_config.decoder_layers` (number of transformer decoder layers).

    Returns:
        `List[Tuple[str, str]]`: the rename pairs, in the order in which they should be applied.
    """
    depths = config.backbone_config.depths
    num_stages = len(depths)

    swin_layers = "model.pixel_level_module.encoder.model.encoder.layers"
    fpn = "model.pixel_level_module.decoder.fpn"

    # (original suffix, converted suffix) of the parameters/buffers of a single Swin block
    swin_block_pairs = [
        ("norm1.weight", "layernorm_before.weight"),
        ("norm1.bias", "layernorm_before.bias"),
        ("attn.relative_position_bias_table", "attention.self.relative_position_bias_table"),
        ("attn.relative_position_index", "attention.self.relative_position_index"),
        ("attn.proj.weight", "attention.output.dense.weight"),
        ("attn.proj.bias", "attention.output.dense.bias"),
        ("norm2.weight", "layernorm_after.weight"),
        ("norm2.bias", "layernorm_after.bias"),
        ("mlp.fc1.weight", "intermediate.dense.weight"),
        ("mlp.fc1.bias", "intermediate.dense.bias"),
        ("mlp.fc2.weight", "output.dense.weight"),
        ("mlp.fc2.bias", "output.dense.bias"),
    ]
    # (original suffix, converted suffix) of one FPN level; `{s}` is filled with the original level index
    fpn_level_pairs = [
        ("adapter_{s}.weight", "proj.0.weight"),
        ("adapter_{s}.norm.weight", "proj.1.weight"),
        ("adapter_{s}.norm.bias", "proj.1.bias"),
        ("layer_{s}.weight", "block.0.weight"),
        ("layer_{s}.norm.weight", "block.1.weight"),
        ("layer_{s}.norm.bias", "block.1.bias"),
    ]
    # (original module, converted module) of one transformer decoder layer, each one has a weight and a bias
    decoder_layer_modules = [
        ("self_attn.out_proj", "self_attn.out_proj"),  # self-attention out projection
        ("multihead_attn.out_proj", "encoder_attn.out_proj"),  # cross-attention out projection
        ("linear1", "fc1"),  # MLP 1
        ("linear2", "fc2"),  # MLP 2
        ("norm1", "self_attn_layer_norm"),  # layernorm 1 (self-attention layernorm)
        ("norm2", "encoder_attn_layer_norm"),  # layernorm 2 (cross-attention layernorm)
        ("norm3", "final_layer_norm"),  # layernorm 3 (final layernorm)
    ]

    rename_keys = []
    # fmt: off
    # stem
    rename_keys.append(("backbone.patch_embed.proj.weight", "model.pixel_level_module.encoder.model.embeddings.patch_embeddings.projection.weight"))
    rename_keys.append(("backbone.patch_embed.proj.bias", "model.pixel_level_module.encoder.model.embeddings.patch_embeddings.projection.bias"))
    rename_keys.append(("backbone.patch_embed.norm.weight", "model.pixel_level_module.encoder.model.embeddings.norm.weight"))
    rename_keys.append(("backbone.patch_embed.norm.bias", "model.pixel_level_module.encoder.model.embeddings.norm.bias"))

    # stages
    for i, depth in enumerate(depths):
        for j in range(depth):
            for src, dest in swin_block_pairs:
                rename_keys.append((f"backbone.layers.{i}.blocks.{j}.{src}", f"{swin_layers}.{i}.blocks.{j}.{dest}"))

        # every stage except the last one is followed by a downsampling layer
        # (this used to be hard-coded as `i < 3`, which only holds for 4 stages)
        if i < num_stages - 1:
            for name in ("reduction.weight", "norm.weight", "norm.bias"):
                rename_keys.append((f"backbone.layers.{i}.downsample.{name}", f"{swin_layers}.{i}.downsample.{name}"))

        # layernorm applied to the feature map returned by this stage
        rename_keys.append((f"backbone.norm{i}.weight", f"model.pixel_level_module.encoder.hidden_states_norms.{i}.weight"))
        rename_keys.append((f"backbone.norm{i}.bias", f"model.pixel_level_module.encoder.hidden_states_norms.{i}.bias"))

    # FPN
    rename_keys.append(("sem_seg_head.layer_4.weight", f"{fpn}.stem.0.weight"))
    rename_keys.append(("sem_seg_head.layer_4.norm.weight", f"{fpn}.stem.1.weight"))
    rename_keys.append(("sem_seg_head.layer_4.norm.bias", f"{fpn}.stem.1.bias"))
    # the original levels 3, 2, 1 become the converted layers 0, 1, 2
    for source_index, target_index in zip(range(3, 0, -1), range(0, 3)):
        for src, dest in fpn_level_pairs:
            rename_keys.append((f"sem_seg_head.{src.format(s=source_index)}", f"{fpn}.layers.{target_index}.{dest}"))
    rename_keys.append(("sem_seg_head.mask_features.weight", "model.pixel_level_module.decoder.mask_projection.weight"))
    rename_keys.append(("sem_seg_head.mask_features.bias", "model.pixel_level_module.decoder.mask_projection.bias"))

    # Transformer decoder
    for idx in range(config.decoder_config.decoder_layers):
        src_layer = f"sem_seg_head.predictor.transformer.decoder.layers.{idx}"
        dest_layer = f"model.transformer_module.decoder.layers.{idx}"
        for src, dest in decoder_layer_modules:
            rename_keys.append((f"{src_layer}.{src}.weight", f"{dest_layer}.{dest}.weight"))
            rename_keys.append((f"{src_layer}.{src}.bias", f"{dest_layer}.{dest}.bias"))

    rename_keys.append(("sem_seg_head.predictor.transformer.decoder.norm.weight", "model.transformer_module.decoder.layernorm.weight"))
    rename_keys.append(("sem_seg_head.predictor.transformer.decoder.norm.bias", "model.transformer_module.decoder.layernorm.bias"))

    # heads on top
    rename_keys.append(("sem_seg_head.predictor.query_embed.weight", "model.transformer_module.queries_embedder.weight"))

    rename_keys.append(("sem_seg_head.predictor.input_proj.weight", "model.transformer_module.input_projection.weight"))
    rename_keys.append(("sem_seg_head.predictor.input_proj.bias", "model.transformer_module.input_projection.bias"))

    rename_keys.append(("sem_seg_head.predictor.class_embed.weight", "class_predictor.weight"))
    rename_keys.append(("sem_seg_head.predictor.class_embed.bias", "class_predictor.bias"))

    for i in range(3):
        rename_keys.append((f"sem_seg_head.predictor.mask_embed.layers.{i}.weight", f"mask_embedder.{i}.0.weight"))
        rename_keys.append((f"sem_seg_head.predictor.mask_embed.layers.{i}.bias", f"mask_embedder.{i}.0.bias"))
    # fmt: on

    return rename_keys
|
||||||
|
|
||||||
|
|
||||||
|
def rename_key(dct, old, new):
    """Move the value stored under `old` to the key `new`, modifying `dct` in place (`KeyError` if `old` is absent)."""
    dct[new] = dct.pop(old)
|
||||||
|
|
||||||
|
|
||||||
|
# the fused qkv projection of every Swin block is split into separate query, key and value entries
def read_in_swin_q_k_v(state_dict, backbone_config):
    """
    Replace each fused `attn.qkv` weight/bias of the backbone by separate query, key and value entries.

    `state_dict` is modified in place: the `backbone.layers.{i}.blocks.{j}.attn.qkv.*` entries are removed and their
    slices along the first axis are stored under the converted `attention.self.{query,key,value}.*` names.

    Args:
        state_dict: Mapping from original parameter names to arrays.
        backbone_config: Object exposing `embed_dim` and `depths` (number of blocks per stage).
    """
    target_prefix = "model.pixel_level_module.encoder.model.encoder.layers"
    # fmt: off
    for stage_idx, num_blocks in enumerate(backbone_config.depths):
        # width of this stage, it doubles from one stage to the next
        dim = int(backbone_config.embed_dim * 2**stage_idx)
        # rows of the fused projection that belong to the query, the key and the value (in that order)
        row_slices = (
            ("query", slice(None, dim)),
            ("key", slice(dim, dim * 2)),
            ("value", slice(-dim, None)),
        )
        for block_idx in range(num_blocks):
            # in the original implementation the input projection is a single matrix + bias
            source = f"backbone.layers.{stage_idx}.blocks.{block_idx}.attn.qkv"
            fused_weight = state_dict.pop(f"{source}.weight")
            fused_bias = state_dict.pop(f"{source}.bias")

            attention = f"{target_prefix}.{stage_idx}.blocks.{block_idx}.attention.self"
            for name, rows in row_slices:
                state_dict[f"{attention}.{name}.weight"] = fused_weight[rows, :]
                state_dict[f"{attention}.{name}.bias"] = fused_bias[rows]
    # fmt: on
|
||||||
|
|
||||||
|
|
||||||
|
# we split up the fused input projection of each decoder attention layer into queries, keys and values
def read_in_decoder_q_k_v(state_dict, config):
    """
    Replace the fused attention input projections of the transformer decoder by separate q/k/v projections.

    `state_dict` is modified in place: for every decoder layer, the `in_proj_weight`/`in_proj_bias` entries of
    `self_attn` and `multihead_attn` are removed and their slices along the first axis are stored as the
    `q_proj`/`k_proj`/`v_proj` weights and biases of `self_attn` and `encoder_attn` respectively.

    Args:
        state_dict: Mapping from original parameter names to arrays.
        config: Object exposing `decoder_config.hidden_size` and `decoder_config.decoder_layers`.
    """
    # fmt: off
    # All slices must use the hidden size of the *decoder*. The query bias used to be sliced with
    # `config.hidden_size` (the top-level config), which is only correct when both sizes happen to be equal.
    hidden_size = config.decoder_config.hidden_size
    # rows of the fused projection that belong to the queries, the keys and the values (in that order)
    row_slices = (
        ("q_proj", slice(None, hidden_size)),
        ("k_proj", slice(hidden_size, hidden_size * 2)),
        ("v_proj", slice(-hidden_size, None)),
    )
    # (original attention module, converted attention module): self-attention first, then cross-attention
    attention_modules = (("self_attn", "self_attn"), ("multihead_attn", "encoder_attn"))

    for idx in range(config.decoder_config.decoder_layers):
        for source_name, target_name in attention_modules:
            # in the original implementation the input projection is a single matrix + bias
            source = f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.{source_name}"
            in_proj_weight = state_dict.pop(f"{source}.in_proj_weight")
            in_proj_bias = state_dict.pop(f"{source}.in_proj_bias")

            target = f"model.transformer_module.decoder.layers.{idx}.{target_name}"
            for projection, rows in row_slices:
                state_dict[f"{target}.{projection}.weight"] = in_proj_weight[rows, :]
                state_dict[f"{target}.{projection}.bias"] = in_proj_bias[rows]
    # fmt: on
|
||||||
|
|
||||||
|
|
||||||
|
# We will verify our results on an image of cute cats
def prepare_img() -> "Image.Image":
    """
    Download the COCO validation image used to sanity check the converted model.

    Returns:
        `PIL.Image.Image`: the image opened from the raw HTTP response stream (not a tensor).

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    response = requests.get(url, stream=True)
    # fail with a clear HTTP error instead of letting PIL choke on an error page
    response.raise_for_status()
    return Image.open(response.raw)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
def convert_maskformer_checkpoint(
    model_name: str, checkpoint_path: str, pytorch_dump_folder_path: str, push_to_hub: bool = False
):
    """
    Copy/paste/tweak model's weights to our MaskFormer structure.

    Args:
        model_name (`str`):
            Name of the checkpoint; it selects the label mapping and the feature extractor settings.
        checkpoint_path (`str`):
            Path to the pickled original checkpoint, whose `"model"` entry maps parameter names to NumPy arrays.
        pytorch_dump_folder_path (`str`):
            Folder in which the converted model and feature extractor are saved. Nothing is saved if `None`.
        push_to_hub (`bool`, *optional*, defaults to `False`):
            Whether to push the converted model and feature extractor to the hub.
    """
    config = get_maskformer_config(model_name)

    # load original state_dict
    # NOTE: unpickling can execute arbitrary code, only convert checkpoints coming from a trusted source
    with open(checkpoint_path, "rb") as f:
        data = pickle.load(f)
    state_dict = data["model"]

    # rename keys
    for src, dest in create_rename_keys(config):
        rename_key(state_dict, src, dest)
    read_in_swin_q_k_v(state_dict, config.backbone_config)
    read_in_decoder_q_k_v(state_dict, config)

    # update to torch tensors (build a new dict rather than rewriting the one being iterated over)
    state_dict = {key: torch.from_numpy(value) for key, value in state_dict.items()}

    # load 🤗 model
    model = MaskFormerForInstanceSegmentation(config)
    model.eval()

    for name, param in model.named_parameters():
        print(name, param.shape)

    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    # the only parameters expected to be absent from the checkpoint are those of the backbone's final layernorm
    assert missing_keys == [
        "model.pixel_level_module.encoder.model.layernorm.weight",
        "model.pixel_level_module.encoder.model.layernorm.bias",
    ], f"Missing keys: {missing_keys}"
    assert len(unexpected_keys) == 0, f"Unexpected keys: {unexpected_keys}"

    # verify results
    image = prepare_img()
    if "vistas" in model_name:
        ignore_index = 65
    elif "cityscapes" in model_name:
        ignore_index = 65535
    else:
        ignore_index = 255
    reduce_labels = "ade" in model_name
    feature_extractor = MaskFormerFeatureExtractor(ignore_index=ignore_index, reduce_labels=reduce_labels)

    inputs = feature_extractor(image, return_tensors="pt")

    outputs = model(**inputs)

    print("Logits:", outputs.class_queries_logits[0, :3, :3])

    # reference logits are only available for this checkpoint, so the numerical check is limited to it
    if model_name == "maskformer-swin-tiny-ade":
        expected_logits = torch.tensor(
            [[3.6353, -4.4770, -2.6065], [0.5081, -4.2394, -3.5343], [2.1909, -5.0353, -1.9323]]
        )
        assert torch.allclose(outputs.class_queries_logits[0, :3, :3], expected_logits, atol=1e-4)
        print("Looks ok!")

    if pytorch_dump_folder_path is not None:
        print(f"Saving model and feature extractor to {pytorch_dump_folder_path}")
        # also create missing parent folders
        Path(pytorch_dump_folder_path).mkdir(parents=True, exist_ok=True)
        model.save_pretrained(pytorch_dump_folder_path)
        feature_extractor.save_pretrained(pytorch_dump_folder_path)

    if push_to_hub:
        print("Pushing model and feature extractor to the hub...")
        model.push_to_hub(f"nielsr/{model_name}")
        feature_extractor.push_to_hub(f"nielsr/{model_name}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # All arguments are optional and fall back to the defaults below
    parser.add_argument(
        "--model_name",
        default="maskformer-swin-tiny-ade",
        type=str,
        # `help` must be a plain string: the 1-tuple used before made `--help` fail with a TypeError
        help="Name of the MaskFormer model you'd like to convert",
    )
    parser.add_argument(
        "--checkpoint_path",
        default="/Users/nielsrogge/Documents/MaskFormer_checkpoints/MaskFormer-Swin-tiny-ADE20k/model.pkl",
        type=str,
        help="Path to the original state dict (pickled .pkl file).",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
    )
    parser.add_argument(
        "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
    )

    args = parser.parse_args()
    convert_maskformer_checkpoint(
        args.model_name, args.checkpoint_path, args.pytorch_dump_folder_path, args.push_to_hub
    )
|
||||||
@@ -275,7 +275,6 @@ class RegNetEncoder(nn.Module):
|
|||||||
return BaseModelOutputWithNoAttention(last_hidden_state=hidden_state, hidden_states=hidden_states)
|
return BaseModelOutputWithNoAttention(last_hidden_state=hidden_state, hidden_states=hidden_states)
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.resnet.modeling_resnet.ResNetPreTrainedModel with ResNet->RegNet,resnet->regnet
|
|
||||||
class RegNetPreTrainedModel(PreTrainedModel):
|
class RegNetPreTrainedModel(PreTrainedModel):
|
||||||
"""
|
"""
|
||||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||||
@@ -287,6 +286,7 @@ class RegNetPreTrainedModel(PreTrainedModel):
|
|||||||
main_input_name = "pixel_values"
|
main_input_name = "pixel_values"
|
||||||
supports_gradient_checkpointing = True
|
supports_gradient_checkpointing = True
|
||||||
|
|
||||||
|
# Copied from transformers.models.resnet.modeling_resnet.ResNetPreTrainedModel._init_weights
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
if isinstance(module, nn.Conv2d):
|
if isinstance(module, nn.Conv2d):
|
||||||
nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
|
nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
|
||||||
|
|||||||
@@ -267,7 +267,7 @@ class ResNetPreTrainedModel(PreTrainedModel):
|
|||||||
nn.init.constant_(module.bias, 0)
|
nn.init.constant_(module.bias, 0)
|
||||||
|
|
||||||
def _set_gradient_checkpointing(self, module, value=False):
|
def _set_gradient_checkpointing(self, module, value=False):
|
||||||
if isinstance(module, ResNetModel):
|
if isinstance(module, (ResNetModel, ResNetBackbone)):
|
||||||
module.gradient_checkpointing = value
|
module.gradient_checkpointing = value
|
||||||
|
|
||||||
|
|
||||||
@@ -436,7 +436,8 @@ class ResNetBackbone(ResNetPreTrainedModel):
|
|||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
self.stage_names = config.stage_names
|
self.stage_names = config.stage_names
|
||||||
self.resnet = ResNetModel(config)
|
self.embedder = ResNetEmbeddings(config)
|
||||||
|
self.encoder = ResNetEncoder(config)
|
||||||
|
|
||||||
self.out_features = config.out_features
|
self.out_features = config.out_features
|
||||||
|
|
||||||
@@ -490,7 +491,9 @@ class ResNetBackbone(ResNetPreTrainedModel):
|
|||||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
)
|
)
|
||||||
|
|
||||||
outputs = self.resnet(pixel_values, output_hidden_states=True, return_dict=True)
|
embedding_output = self.embedder(pixel_values)
|
||||||
|
|
||||||
|
outputs = self.encoder(embedding_output, output_hidden_states=True, return_dict=True)
|
||||||
|
|
||||||
hidden_states = outputs.hidden_states
|
hidden_states = outputs.hidden_states
|
||||||
|
|
||||||
|
|||||||
@@ -320,16 +320,16 @@ def prepare_img():
|
|||||||
@require_vision
|
@require_vision
|
||||||
@slow
|
@slow
|
||||||
class MaskFormerModelIntegrationTest(unittest.TestCase):
|
class MaskFormerModelIntegrationTest(unittest.TestCase):
|
||||||
@cached_property
|
|
||||||
def model_checkpoints(self):
|
|
||||||
return "facebook/maskformer-swin-small-coco"
|
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def default_feature_extractor(self):
|
def default_feature_extractor(self):
|
||||||
return MaskFormerFeatureExtractor.from_pretrained(self.model_checkpoints) if is_vision_available() else None
|
return (
|
||||||
|
MaskFormerFeatureExtractor.from_pretrained("facebook/maskformer-swin-small-coco")
|
||||||
|
if is_vision_available()
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
def test_inference_no_head(self):
|
def test_inference_no_head(self):
|
||||||
model = MaskFormerModel.from_pretrained(self.model_checkpoints).to(torch_device)
|
model = MaskFormerModel.from_pretrained("facebook/maskformer-swin-small-coco").to(torch_device)
|
||||||
feature_extractor = self.default_feature_extractor
|
feature_extractor = self.default_feature_extractor
|
||||||
image = prepare_img()
|
image = prepare_img()
|
||||||
inputs = feature_extractor(image, return_tensors="pt").to(torch_device)
|
inputs = feature_extractor(image, return_tensors="pt").to(torch_device)
|
||||||
@@ -370,7 +370,11 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def test_inference_instance_segmentation_head(self):
|
def test_inference_instance_segmentation_head(self):
|
||||||
model = MaskFormerForInstanceSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval()
|
model = (
|
||||||
|
MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-small-coco")
|
||||||
|
.to(torch_device)
|
||||||
|
.eval()
|
||||||
|
)
|
||||||
feature_extractor = self.default_feature_extractor
|
feature_extractor = self.default_feature_extractor
|
||||||
image = prepare_img()
|
image = prepare_img()
|
||||||
inputs = feature_extractor(image, return_tensors="pt").to(torch_device)
|
inputs = feature_extractor(image, return_tensors="pt").to(torch_device)
|
||||||
@@ -385,7 +389,8 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
|
|||||||
# masks_queries_logits
|
# masks_queries_logits
|
||||||
masks_queries_logits = outputs.masks_queries_logits
|
masks_queries_logits = outputs.masks_queries_logits
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
masks_queries_logits.shape, (1, model.config.num_queries, inputs_shape[-2] // 4, inputs_shape[-1] // 4)
|
masks_queries_logits.shape,
|
||||||
|
(1, model.config.decoder_config.num_queries, inputs_shape[-2] // 4, inputs_shape[-1] // 4),
|
||||||
)
|
)
|
||||||
expected_slice = [
|
expected_slice = [
|
||||||
[-1.3737124, -1.7724937, -1.9364233],
|
[-1.3737124, -1.7724937, -1.9364233],
|
||||||
@@ -396,7 +401,9 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
|
|||||||
self.assertTrue(torch.allclose(masks_queries_logits[0, 0, :3, :3], expected_slice, atol=TOLERANCE))
|
self.assertTrue(torch.allclose(masks_queries_logits[0, 0, :3, :3], expected_slice, atol=TOLERANCE))
|
||||||
# class_queries_logits
|
# class_queries_logits
|
||||||
class_queries_logits = outputs.class_queries_logits
|
class_queries_logits = outputs.class_queries_logits
|
||||||
self.assertEqual(class_queries_logits.shape, (1, model.config.num_queries, model.config.num_labels + 1))
|
self.assertEqual(
|
||||||
|
class_queries_logits.shape, (1, model.config.decoder_config.num_queries, model.config.num_labels + 1)
|
||||||
|
)
|
||||||
expected_slice = torch.tensor(
|
expected_slice = torch.tensor(
|
||||||
[
|
[
|
||||||
[1.6512e00, -5.2572e00, -3.3519e00],
|
[1.6512e00, -5.2572e00, -3.3519e00],
|
||||||
@@ -406,8 +413,48 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
|
|||||||
).to(torch_device)
|
).to(torch_device)
|
||||||
self.assertTrue(torch.allclose(outputs.class_queries_logits[0, :3, :3], expected_slice, atol=TOLERANCE))
|
self.assertTrue(torch.allclose(outputs.class_queries_logits[0, :3, :3], expected_slice, atol=TOLERANCE))
|
||||||
|
|
||||||
|
def test_inference_instance_segmentation_head_resnet_backbone(self):
|
||||||
|
model = (
|
||||||
|
MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-resnet101-coco-stuff")
|
||||||
|
.to(torch_device)
|
||||||
|
.eval()
|
||||||
|
)
|
||||||
|
feature_extractor = self.default_feature_extractor
|
||||||
|
image = prepare_img()
|
||||||
|
inputs = feature_extractor(image, return_tensors="pt").to(torch_device)
|
||||||
|
inputs_shape = inputs["pixel_values"].shape
|
||||||
|
# check size is divisible by 32
|
||||||
|
self.assertTrue((inputs_shape[-1] % 32) == 0 and (inputs_shape[-2] % 32) == 0)
|
||||||
|
# check size
|
||||||
|
self.assertEqual(inputs_shape, (1, 3, 800, 1088))
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(**inputs)
|
||||||
|
# masks_queries_logits
|
||||||
|
masks_queries_logits = outputs.masks_queries_logits
|
||||||
|
self.assertEqual(
|
||||||
|
masks_queries_logits.shape,
|
||||||
|
(1, model.config.decoder_config.num_queries, inputs_shape[-2] // 4, inputs_shape[-1] // 4),
|
||||||
|
)
|
||||||
|
expected_slice = [[-0.9046, -2.6366, -4.6062], [-3.4179, -5.7890, -8.8057], [-4.9179, -7.6560, -10.7711]]
|
||||||
|
expected_slice = torch.tensor(expected_slice).to(torch_device)
|
||||||
|
self.assertTrue(torch.allclose(masks_queries_logits[0, 0, :3, :3], expected_slice, atol=TOLERANCE))
|
||||||
|
# class_queries_logits
|
||||||
|
class_queries_logits = outputs.class_queries_logits
|
||||||
|
self.assertEqual(
|
||||||
|
class_queries_logits.shape, (1, model.config.decoder_config.num_queries, model.config.num_labels + 1)
|
||||||
|
)
|
||||||
|
expected_slice = torch.tensor(
|
||||||
|
[[4.7188, -3.2585, -2.8857], [6.6871, -2.9181, -1.2487], [7.2449, -2.2764, -2.1874]]
|
||||||
|
).to(torch_device)
|
||||||
|
self.assertTrue(torch.allclose(outputs.class_queries_logits[0, :3, :3], expected_slice, atol=TOLERANCE))
|
||||||
|
|
||||||
def test_with_segmentation_maps_and_loss(self):
|
def test_with_segmentation_maps_and_loss(self):
|
||||||
model = MaskFormerForInstanceSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval()
|
model = (
|
||||||
|
MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-small-coco")
|
||||||
|
.to(torch_device)
|
||||||
|
.eval()
|
||||||
|
)
|
||||||
feature_extractor = self.default_feature_extractor
|
feature_extractor = self.default_feature_extractor
|
||||||
|
|
||||||
inputs = feature_extractor(
|
inputs = feature_extractor(
|
||||||
|
|||||||
Reference in New Issue
Block a user