Add BeitForSemanticSegmentation (#14096)
* Add first draft * Make forward pass work * Improve conversion script * Add notebook that checks if it works * Add BeitForSemanticSegmentation to the tests * More improvements * Make BeitForSemanticSegmentation consistent with Segformer * Small bug fix * Add BeitForSemanticSegmentation to docs * Make sure model doesn't output hidden states when the user doesn't want to * Make it possible to convert the large model * Fix issue * Fix conversion script for large model * Add auxiliary_head option to semantic segmentation model * Apply suggestions from @sgugger's review * Apply suggestions from code review * Fix failing test Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
@@ -98,6 +98,13 @@ BeitForImageClassification
|
||||
:members: forward
|
||||
|
||||
|
||||
BeitForSemanticSegmentation
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.BeitForSemanticSegmentation
|
||||
:members: forward
|
||||
|
||||
|
||||
FlaxBeitModel
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
||||
@@ -638,6 +638,7 @@ if is_torch_available():
|
||||
"BEIT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"BeitForImageClassification",
|
||||
"BeitForMaskedImageModeling",
|
||||
"BeitForSemanticSegmentation",
|
||||
"BeitModel",
|
||||
"BeitPreTrainedModel",
|
||||
]
|
||||
@@ -2483,6 +2484,7 @@ if TYPE_CHECKING:
|
||||
BEIT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
BeitForImageClassification,
|
||||
BeitForMaskedImageModeling,
|
||||
BeitForSemanticSegmentation,
|
||||
BeitModel,
|
||||
BeitPreTrainedModel,
|
||||
)
|
||||
|
||||
@@ -33,6 +33,7 @@ if is_torch_available():
|
||||
"BEIT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"BeitForImageClassification",
|
||||
"BeitForMaskedImageModeling",
|
||||
"BeitForSemanticSegmentation",
|
||||
"BeitModel",
|
||||
"BeitPreTrainedModel",
|
||||
]
|
||||
@@ -57,6 +58,7 @@ if TYPE_CHECKING:
|
||||
BEIT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
BeitForImageClassification,
|
||||
BeitForMaskedImageModeling,
|
||||
BeitForSemanticSegmentation,
|
||||
BeitModel,
|
||||
BeitPreTrainedModel,
|
||||
)
|
||||
|
||||
@@ -78,6 +78,20 @@ class BeitConfig(PretrainedConfig):
|
||||
use_mean_pooling (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether to mean pool the final hidden states of the patches instead of using the final hidden state of the
|
||||
CLS token, before applying the classification head.
|
||||
out_indices (:obj:`List[int]`, `optional`, defaults to :obj:`[3, 5, 7, 11]`):
|
||||
Indices of the feature maps to use for semantic segmentation.
|
||||
pool_scales (:obj:`Tuple[int]`, `optional`, defaults to :obj:`[1, 2, 3, 6]`):
|
||||
Pooling scales used in Pooling Pyramid Module applied on the last feature map.
|
||||
use_auxiliary_head (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether to use an auxiliary head during training.
|
||||
auxiliary_loss_weight (:obj:`float`, `optional`, defaults to 0.4):
|
||||
Weight of the cross-entropy loss of the auxiliary head.
|
||||
auxiliary_channels (:obj:`int`, `optional`, defaults to 256):
|
||||
Number of channels to use in the auxiliary head.
|
||||
auxiliary_num_convs (:obj:`int`, `optional`, defaults to 1):
|
||||
Number of convolutional layers to use in the auxiliary head.
|
||||
auxiliary_concat_input (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether to concatenate the output of the auxiliary head with the input before the classification layer.
|
||||
|
||||
Example::
|
||||
|
||||
@@ -117,6 +131,13 @@ class BeitConfig(PretrainedConfig):
|
||||
layer_scale_init_value=0.1,
|
||||
drop_path_rate=0.1,
|
||||
use_mean_pooling=True,
|
||||
out_indices=[3, 5, 7, 11],
|
||||
pool_scales=[1, 2, 3, 6],
|
||||
use_auxiliary_head=True,
|
||||
auxiliary_loss_weight=0.4,
|
||||
auxiliary_channels=256,
|
||||
auxiliary_num_convs=1,
|
||||
auxiliary_concat_input=False,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
@@ -142,3 +163,12 @@ class BeitConfig(PretrainedConfig):
|
||||
self.layer_scale_init_value = layer_scale_init_value
|
||||
self.drop_path_rate = drop_path_rate
|
||||
self.use_mean_pooling = use_mean_pooling
|
||||
# decode head attributes (semantic segmentation)
|
||||
self.out_indices = out_indices
|
||||
self.pool_scales = pool_scales
|
||||
# auxiliary head attributes (semantic segmentation)
|
||||
self.use_auxiliary_head = use_auxiliary_head
|
||||
self.auxiliary_loss_weight = auxiliary_loss_weight
|
||||
self.auxiliary_channels = auxiliary_channels
|
||||
self.auxiliary_num_convs = auxiliary_num_convs
|
||||
self.auxiliary_concat_input = auxiliary_concat_input
|
||||
|
||||
@@ -20,11 +20,18 @@ import json
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from PIL import Image
|
||||
|
||||
import requests
|
||||
from huggingface_hub import cached_download, hf_hub_url
|
||||
from transformers import BeitConfig, BeitFeatureExtractor, BeitForImageClassification, BeitForMaskedImageModeling
|
||||
from transformers import (
|
||||
BeitConfig,
|
||||
BeitFeatureExtractor,
|
||||
BeitForImageClassification,
|
||||
BeitForMaskedImageModeling,
|
||||
BeitForSemanticSegmentation,
|
||||
)
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
@@ -33,27 +40,33 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# here we list all keys to be renamed (original name on the left, our name on the right)
|
||||
def create_rename_keys(config, has_lm_head=False):
|
||||
def create_rename_keys(config, has_lm_head=False, is_semantic=False):
|
||||
prefix = "backbone." if is_semantic else ""
|
||||
|
||||
rename_keys = []
|
||||
for i in range(config.num_hidden_layers):
|
||||
# encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
|
||||
rename_keys.append((f"blocks.{i}.norm1.weight", f"beit.encoder.layer.{i}.layernorm_before.weight"))
|
||||
rename_keys.append((f"blocks.{i}.norm1.bias", f"beit.encoder.layer.{i}.layernorm_before.bias"))
|
||||
rename_keys.append((f"blocks.{i}.attn.proj.weight", f"beit.encoder.layer.{i}.attention.output.dense.weight"))
|
||||
rename_keys.append((f"blocks.{i}.attn.proj.bias", f"beit.encoder.layer.{i}.attention.output.dense.bias"))
|
||||
rename_keys.append((f"blocks.{i}.norm2.weight", f"beit.encoder.layer.{i}.layernorm_after.weight"))
|
||||
rename_keys.append((f"blocks.{i}.norm2.bias", f"beit.encoder.layer.{i}.layernorm_after.bias"))
|
||||
rename_keys.append((f"blocks.{i}.mlp.fc1.weight", f"beit.encoder.layer.{i}.intermediate.dense.weight"))
|
||||
rename_keys.append((f"blocks.{i}.mlp.fc1.bias", f"beit.encoder.layer.{i}.intermediate.dense.bias"))
|
||||
rename_keys.append((f"blocks.{i}.mlp.fc2.weight", f"beit.encoder.layer.{i}.output.dense.weight"))
|
||||
rename_keys.append((f"blocks.{i}.mlp.fc2.bias", f"beit.encoder.layer.{i}.output.dense.bias"))
|
||||
rename_keys.append((f"{prefix}blocks.{i}.norm1.weight", f"beit.encoder.layer.{i}.layernorm_before.weight"))
|
||||
rename_keys.append((f"{prefix}blocks.{i}.norm1.bias", f"beit.encoder.layer.{i}.layernorm_before.bias"))
|
||||
rename_keys.append(
|
||||
(f"{prefix}blocks.{i}.attn.proj.weight", f"beit.encoder.layer.{i}.attention.output.dense.weight")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"{prefix}blocks.{i}.attn.proj.bias", f"beit.encoder.layer.{i}.attention.output.dense.bias")
|
||||
)
|
||||
rename_keys.append((f"{prefix}blocks.{i}.norm2.weight", f"beit.encoder.layer.{i}.layernorm_after.weight"))
|
||||
rename_keys.append((f"{prefix}blocks.{i}.norm2.bias", f"beit.encoder.layer.{i}.layernorm_after.bias"))
|
||||
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc1.weight", f"beit.encoder.layer.{i}.intermediate.dense.weight"))
|
||||
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc1.bias", f"beit.encoder.layer.{i}.intermediate.dense.bias"))
|
||||
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc2.weight", f"beit.encoder.layer.{i}.output.dense.weight"))
|
||||
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc2.bias", f"beit.encoder.layer.{i}.output.dense.bias"))
|
||||
|
||||
# projection layer + position embeddings
|
||||
rename_keys.extend(
|
||||
[
|
||||
("cls_token", "beit.embeddings.cls_token"),
|
||||
("patch_embed.proj.weight", "beit.embeddings.patch_embeddings.projection.weight"),
|
||||
("patch_embed.proj.bias", "beit.embeddings.patch_embeddings.projection.bias"),
|
||||
(f"{prefix}cls_token", "beit.embeddings.cls_token"),
|
||||
(f"{prefix}patch_embed.proj.weight", "beit.embeddings.patch_embeddings.projection.weight"),
|
||||
(f"{prefix}patch_embed.proj.bias", "beit.embeddings.patch_embeddings.projection.bias"),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -74,6 +87,16 @@ def create_rename_keys(config, has_lm_head=False):
|
||||
("norm.bias", "layernorm.bias"),
|
||||
]
|
||||
)
|
||||
elif is_semantic:
|
||||
# semantic segmentation classification heads
|
||||
rename_keys.extend(
|
||||
[
|
||||
("decode_head.conv_seg.weight", "decode_head.classifier.weight"),
|
||||
("decode_head.conv_seg.bias", "decode_head.classifier.bias"),
|
||||
("auxiliary_head.conv_seg.weight", "auxiliary_head.classifier.weight"),
|
||||
("auxiliary_head.conv_seg.bias", "auxiliary_head.classifier.bias"),
|
||||
]
|
||||
)
|
||||
else:
|
||||
# layernorm + classification head
|
||||
rename_keys.extend(
|
||||
@@ -89,45 +112,45 @@ def create_rename_keys(config, has_lm_head=False):
|
||||
|
||||
|
||||
# we split up the matrix of each encoder layer into queries, keys and values
|
||||
def read_in_q_k_v(state_dict, config, has_lm_head=False):
|
||||
def read_in_q_k_v(state_dict, config, has_lm_head=False, is_semantic=False):
|
||||
for i in range(config.num_hidden_layers):
|
||||
prefix = "beit."
|
||||
prefix = "backbone." if is_semantic else ""
|
||||
# queries, keys and values
|
||||
in_proj_weight = state_dict.pop(f"blocks.{i}.attn.qkv.weight")
|
||||
q_bias = state_dict.pop(f"blocks.{i}.attn.q_bias")
|
||||
v_bias = state_dict.pop(f"blocks.{i}.attn.v_bias")
|
||||
in_proj_weight = state_dict.pop(f"{prefix}blocks.{i}.attn.qkv.weight")
|
||||
q_bias = state_dict.pop(f"{prefix}blocks.{i}.attn.q_bias")
|
||||
v_bias = state_dict.pop(f"{prefix}blocks.{i}.attn.v_bias")
|
||||
|
||||
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[
|
||||
state_dict[f"beit.encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[
|
||||
: config.hidden_size, :
|
||||
]
|
||||
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.bias"] = q_bias
|
||||
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
|
||||
state_dict[f"beit.encoder.layer.{i}.attention.attention.query.bias"] = q_bias
|
||||
state_dict[f"beit.encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
|
||||
config.hidden_size : config.hidden_size * 2, :
|
||||
]
|
||||
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[
|
||||
state_dict[f"beit.encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[
|
||||
-config.hidden_size :, :
|
||||
]
|
||||
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.bias"] = v_bias
|
||||
state_dict[f"beit.encoder.layer.{i}.attention.attention.value.bias"] = v_bias
|
||||
|
||||
# gamma_1 and gamma_2
|
||||
# we call them lambda because otherwise they are renamed when using .from_pretrained
|
||||
gamma_1 = state_dict.pop(f"blocks.{i}.gamma_1")
|
||||
gamma_2 = state_dict.pop(f"blocks.{i}.gamma_2")
|
||||
gamma_1 = state_dict.pop(f"{prefix}blocks.{i}.gamma_1")
|
||||
gamma_2 = state_dict.pop(f"{prefix}blocks.{i}.gamma_2")
|
||||
|
||||
state_dict[f"{prefix}encoder.layer.{i}.lambda_1"] = gamma_1
|
||||
state_dict[f"{prefix}encoder.layer.{i}.lambda_2"] = gamma_2
|
||||
state_dict[f"beit.encoder.layer.{i}.lambda_1"] = gamma_1
|
||||
state_dict[f"beit.encoder.layer.{i}.lambda_2"] = gamma_2
|
||||
|
||||
# relative_position bias table + index
|
||||
if not has_lm_head:
|
||||
# each layer has its own relative position bias
|
||||
table = state_dict.pop(f"blocks.{i}.attn.relative_position_bias_table")
|
||||
index = state_dict.pop(f"blocks.{i}.attn.relative_position_index")
|
||||
table = state_dict.pop(f"{prefix}blocks.{i}.attn.relative_position_bias_table")
|
||||
index = state_dict.pop(f"{prefix}blocks.{i}.attn.relative_position_index")
|
||||
|
||||
state_dict[
|
||||
f"{prefix}encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_bias_table"
|
||||
f"beit.encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_bias_table"
|
||||
] = table
|
||||
state_dict[
|
||||
f"{prefix}encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_index"
|
||||
f"beit.encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_index"
|
||||
] = index
|
||||
|
||||
|
||||
@@ -152,6 +175,7 @@ def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path):
|
||||
# define default BEiT configuration
|
||||
config = BeitConfig()
|
||||
has_lm_head = False
|
||||
is_semantic = False
|
||||
repo_id = "datasets/huggingface/label-files"
|
||||
# set config parameters based on URL
|
||||
if checkpoint_url[-9:-4] == "pt22k":
|
||||
@@ -185,8 +209,19 @@ def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path):
|
||||
config.image_size = 384
|
||||
if "512" in checkpoint_url:
|
||||
config.image_size = 512
|
||||
elif "ade20k" in checkpoint_url:
|
||||
# fine-tuning
|
||||
config.use_relative_position_bias = True
|
||||
config.num_labels = 150
|
||||
filename = "ade20k-id2label.json"
|
||||
id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename)), "r"))
|
||||
id2label = {int(k): v for k, v in id2label.items()}
|
||||
config.id2label = id2label
|
||||
config.label2id = {v: k for k, v in id2label.items()}
|
||||
config.image_size = 640
|
||||
is_semantic = True
|
||||
else:
|
||||
raise ValueError("Checkpoint not supported, URL should either end with 'pt22k', 'ft22k' or 'to1k'")
|
||||
raise ValueError("Checkpoint not supported, URL should either end with 'pt22k', 'ft22k', 'to1k' or 'ade20k'")
|
||||
|
||||
# size of the architecture
|
||||
if "base" in checkpoint_url:
|
||||
@@ -196,27 +231,48 @@ def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path):
|
||||
config.intermediate_size = 4096
|
||||
config.num_hidden_layers = 24
|
||||
config.num_attention_heads = 16
|
||||
if "ade20k" in checkpoint_url:
|
||||
config.image_size = 640
|
||||
config.out_indices = [7, 11, 15, 23]
|
||||
else:
|
||||
raise ValueError("Should either find 'base' or 'large' in checkpoint URL")
|
||||
|
||||
# load state_dict of original model, remove and rename some keys
|
||||
state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu", check_hash=True)["model"]
|
||||
rename_keys = create_rename_keys(config, has_lm_head=has_lm_head)
|
||||
state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu", check_hash=True)
|
||||
state_dict = state_dict["model"] if "ade20k" not in checkpoint_url else state_dict["state_dict"]
|
||||
|
||||
rename_keys = create_rename_keys(config, has_lm_head=has_lm_head, is_semantic=is_semantic)
|
||||
for src, dest in rename_keys:
|
||||
rename_key(state_dict, src, dest)
|
||||
read_in_q_k_v(state_dict, config, has_lm_head=has_lm_head)
|
||||
read_in_q_k_v(state_dict, config, has_lm_head=has_lm_head, is_semantic=is_semantic)
|
||||
if is_semantic:
|
||||
# add prefix to decoder keys
|
||||
for key, val in state_dict.copy().items():
|
||||
val = state_dict.pop(key)
|
||||
if key.startswith("backbone.fpn"):
|
||||
key = key.replace("backbone.fpn", "fpn")
|
||||
state_dict[key] = val
|
||||
|
||||
# load HuggingFace model
|
||||
if checkpoint_url[-9:-4] == "pt22k":
|
||||
model = BeitForMaskedImageModeling(config)
|
||||
elif "ade20k" in checkpoint_url:
|
||||
model = BeitForSemanticSegmentation(config)
|
||||
else:
|
||||
model = BeitForImageClassification(config)
|
||||
model.eval()
|
||||
model.load_state_dict(state_dict)
|
||||
|
||||
# Check outputs on an image
|
||||
feature_extractor = BeitFeatureExtractor(size=config.image_size, resample=Image.BILINEAR, do_center_crop=False)
|
||||
encoding = feature_extractor(images=prepare_img(), return_tensors="pt")
|
||||
if is_semantic:
|
||||
feature_extractor = BeitFeatureExtractor(size=config.image_size, do_center_crop=False)
|
||||
ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
|
||||
image = Image.open(ds[0]["file"])
|
||||
else:
|
||||
feature_extractor = BeitFeatureExtractor(size=config.image_size, resample=Image.BILINEAR, do_center_crop=False)
|
||||
image = prepare_img()
|
||||
|
||||
encoding = feature_extractor(images=image, return_tensors="pt")
|
||||
pixel_values = encoding["pixel_values"]
|
||||
|
||||
outputs = model(pixel_values)
|
||||
@@ -257,15 +313,39 @@ def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path):
|
||||
elif checkpoint_url[:-4].endswith("beit_large_patch16_512_pt22k_ft22kto1k"):
|
||||
expected_logits = torch.tensor([-0.3062, 0.7261, 0.4852])
|
||||
expected_class_idx = 761
|
||||
elif checkpoint_url[:-4].endswith("beit_base_patch16_640_pt22k_ft22ktoade20k"):
|
||||
expected_shape = (1, 150, 160, 160)
|
||||
expected_logits = torch.tensor(
|
||||
[
|
||||
[[-4.9225, -2.3954, -3.0522], [-2.8822, -1.0046, -1.7561], [-2.9549, -1.3228, -2.1347]],
|
||||
[[-5.8168, -3.4129, -4.0778], [-3.8651, -2.2214, -3.0277], [-3.8356, -2.4643, -3.3535]],
|
||||
[[-0.0078, 3.9952, 4.0754], [2.9856, 4.6944, 5.0035], [3.2413, 4.7813, 4.9969]],
|
||||
]
|
||||
)
|
||||
elif checkpoint_url[:-4].endswith("beit_large_patch16_640_pt22k_ft22ktoade20k"):
|
||||
expected_shape = (1, 150, 160, 160)
|
||||
expected_logits = torch.tensor(
|
||||
[
|
||||
[[-4.3305, -2.3049, -3.0161], [-2.9591, -1.5305, -2.2251], [-3.4198, -1.8004, -2.9062]],
|
||||
[[-5.8922, -3.7435, -4.3978], [-4.2063, -2.7872, -3.4755], [-4.2791, -3.1874, -4.1681]],
|
||||
[[0.9895, 4.3467, 4.7663], [4.2476, 5.6830, 6.1518], [4.5550, 6.2495, 6.5154]],
|
||||
]
|
||||
)
|
||||
else:
|
||||
raise ValueError("Can't verify logits as model is not supported")
|
||||
|
||||
assert logits.shape == expected_shape, "Shape of logits not as expected"
|
||||
print("Shape of logits:", logits.shape)
|
||||
if not has_lm_head:
|
||||
print("Predicted class idx:", logits.argmax(-1).item())
|
||||
assert torch.allclose(logits[0, :3], expected_logits, atol=1e-3), "First elements of logits not as expected"
|
||||
assert logits.argmax(-1).item() == expected_class_idx, "Predicted class index not as expected"
|
||||
if is_semantic:
|
||||
assert torch.allclose(
|
||||
logits[0, :3, :3, :3], expected_logits, atol=1e-3
|
||||
), "First elements of logits not as expected"
|
||||
else:
|
||||
print("Predicted class idx:", logits.argmax(-1).item())
|
||||
assert torch.allclose(
|
||||
logits[0, :3], expected_logits, atol=1e-3
|
||||
), "First elements of logits not as expected"
|
||||
assert logits.argmax(-1).item() == expected_class_idx, "Predicted class index not as expected"
|
||||
|
||||
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||
print(f"Saving model to {pytorch_dump_folder_path}")
|
||||
|
||||
@@ -163,6 +163,7 @@ class PatchEmbeddings(nn.Module):
|
||||
f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
|
||||
)
|
||||
x = self.projection(pixel_values).flatten(2).transpose(1, 2)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
@@ -499,7 +500,7 @@ class BeitPreTrainedModel(PreTrainedModel):
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
@@ -851,3 +852,354 @@ class BeitForImageClassification(BeitPreTrainedModel):
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
class BeitConvModule(nn.Module):
|
||||
"""
|
||||
A convolutional block that bundles conv/norm/activation layers. This block simplifies the usage of convolution
|
||||
layers, which are commonly used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).
|
||||
|
||||
Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, out_channels, kernel_size, padding=0, bias=False, dilation=1):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
padding=padding,
|
||||
bias=bias,
|
||||
dilation=dilation,
|
||||
)
|
||||
self.bn = nn.BatchNorm2d(out_channels)
|
||||
self.activation = nn.ReLU()
|
||||
|
||||
def forward(self, input):
|
||||
output = self.conv(input)
|
||||
output = self.bn(output)
|
||||
output = self.activation(output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class BeitPyramidPoolingModule(nn.ModuleList):
|
||||
"""
|
||||
Pyramid Pooling Module (PPM) used in PSPNet.
|
||||
|
||||
Args:
|
||||
pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
|
||||
Module.
|
||||
in_channels (int): Input channels.
|
||||
channels (int): Channels after modules, before conv_seg.
|
||||
align_corners (bool): align_corners argument of F.interpolate.
|
||||
|
||||
Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
|
||||
"""
|
||||
|
||||
def __init__(self, pool_scales, in_channels, channels, align_corners):
|
||||
super().__init__()
|
||||
self.pool_scales = pool_scales
|
||||
self.align_corners = align_corners
|
||||
self.in_channels = in_channels
|
||||
self.channels = channels
|
||||
for pool_scale in pool_scales:
|
||||
self.append(
|
||||
nn.Sequential(
|
||||
nn.AdaptiveAvgPool2d(pool_scale),
|
||||
BeitConvModule(self.in_channels, self.channels, kernel_size=1),
|
||||
)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
ppm_outs = []
|
||||
for ppm in self:
|
||||
ppm_out = ppm(x)
|
||||
upsampled_ppm_out = nn.functional.interpolate(
|
||||
ppm_out, size=x.size()[2:], mode="bilinear", align_corners=self.align_corners
|
||||
)
|
||||
ppm_outs.append(upsampled_ppm_out)
|
||||
return ppm_outs
|
||||
|
||||
|
||||
class BeitUperHead(nn.Module):
|
||||
"""
|
||||
Unified Perceptual Parsing for Scene Understanding. This head is the implementation of `UPerNet
|
||||
<https://arxiv.org/abs/1807.10221>`_.
|
||||
|
||||
Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
||||
self.pool_scales = config.pool_scales # e.g. (1, 2, 3, 6)
|
||||
self.in_channels = [config.hidden_size] * 4 # e.g. [768, 768, 768, 768]
|
||||
self.channels = config.hidden_size
|
||||
self.align_corners = False
|
||||
self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1)
|
||||
|
||||
# PSP Module
|
||||
self.psp_modules = BeitPyramidPoolingModule(
|
||||
self.pool_scales,
|
||||
self.in_channels[-1],
|
||||
self.channels,
|
||||
align_corners=self.align_corners,
|
||||
)
|
||||
self.bottleneck = BeitConvModule(
|
||||
self.in_channels[-1] + len(self.pool_scales) * self.channels,
|
||||
self.channels,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
)
|
||||
# FPN Module
|
||||
self.lateral_convs = nn.ModuleList()
|
||||
self.fpn_convs = nn.ModuleList()
|
||||
for in_channels in self.in_channels[:-1]: # skip the top layer
|
||||
l_conv = BeitConvModule(in_channels, self.channels, kernel_size=1)
|
||||
fpn_conv = BeitConvModule(self.channels, self.channels, kernel_size=3, padding=1)
|
||||
self.lateral_convs.append(l_conv)
|
||||
self.fpn_convs.append(fpn_conv)
|
||||
|
||||
self.fpn_bottleneck = BeitConvModule(
|
||||
len(self.in_channels) * self.channels,
|
||||
self.channels,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
)
|
||||
|
||||
def psp_forward(self, inputs):
|
||||
x = inputs[-1]
|
||||
psp_outs = [x]
|
||||
psp_outs.extend(self.psp_modules(x))
|
||||
psp_outs = torch.cat(psp_outs, dim=1)
|
||||
output = self.bottleneck(psp_outs)
|
||||
|
||||
return output
|
||||
|
||||
def forward(self, encoder_hidden_states):
|
||||
# build laterals
|
||||
laterals = [lateral_conv(encoder_hidden_states[i]) for i, lateral_conv in enumerate(self.lateral_convs)]
|
||||
|
||||
laterals.append(self.psp_forward(encoder_hidden_states))
|
||||
|
||||
# build top-down path
|
||||
used_backbone_levels = len(laterals)
|
||||
for i in range(used_backbone_levels - 1, 0, -1):
|
||||
prev_shape = laterals[i - 1].shape[2:]
|
||||
laterals[i - 1] += nn.functional.interpolate(
|
||||
laterals[i], size=prev_shape, mode="bilinear", align_corners=self.align_corners
|
||||
)
|
||||
|
||||
# build outputs
|
||||
fpn_outs = [self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels - 1)]
|
||||
# append psp feature
|
||||
fpn_outs.append(laterals[-1])
|
||||
|
||||
for i in range(used_backbone_levels - 1, 0, -1):
|
||||
fpn_outs[i] = nn.functional.interpolate(
|
||||
fpn_outs[i], size=fpn_outs[0].shape[2:], mode="bilinear", align_corners=self.align_corners
|
||||
)
|
||||
fpn_outs = torch.cat(fpn_outs, dim=1)
|
||||
output = self.fpn_bottleneck(fpn_outs)
|
||||
output = self.classifier(output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class BeitFCNHead(nn.Module):
|
||||
"""
|
||||
Fully Convolution Networks for Semantic Segmentation. This head is implemented of `FCNNet
|
||||
<https://arxiv.org/abs/1411.4038>`_.
|
||||
|
||||
Args:
|
||||
config (BeitConfig): Configuration.
|
||||
in_channels
|
||||
kernel_size (int): The kernel size for convs in the head. Default: 3.
|
||||
dilation (int): The dilation rate for convs in the head. Default: 1.
|
||||
|
||||
|
||||
Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
|
||||
"""
|
||||
|
||||
def __init__(self, config, in_index=2, kernel_size=3, dilation=1):
|
||||
super().__init__()
|
||||
self.in_channels = config.hidden_size
|
||||
self.channels = config.auxiliary_channels
|
||||
self.num_convs = config.auxiliary_num_convs
|
||||
self.concat_input = config.auxiliary_concat_input
|
||||
self.in_index = in_index
|
||||
|
||||
conv_padding = (kernel_size // 2) * dilation
|
||||
convs = []
|
||||
convs.append(
|
||||
BeitConvModule(
|
||||
self.in_channels, self.channels, kernel_size=kernel_size, padding=conv_padding, dilation=dilation
|
||||
)
|
||||
)
|
||||
for i in range(self.num_convs - 1):
|
||||
convs.append(
|
||||
BeitConvModule(
|
||||
self.channels, self.channels, kernel_size=kernel_size, padding=conv_padding, dilation=dilation
|
||||
)
|
||||
)
|
||||
if self.num_convs == 0:
|
||||
self.convs = nn.Identity()
|
||||
else:
|
||||
self.convs = nn.Sequential(*convs)
|
||||
if self.concat_input:
|
||||
self.conv_cat = BeitConvModule(
|
||||
self.in_channels + self.channels, self.channels, kernel_size=kernel_size, padding=kernel_size // 2
|
||||
)
|
||||
|
||||
self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1)
|
||||
|
||||
def forward(self, encoder_hidden_states):
|
||||
# just take the relevant feature maps
|
||||
hidden_states = encoder_hidden_states[self.in_index]
|
||||
output = self.convs(hidden_states)
|
||||
if self.concat_input:
|
||||
output = self.conv_cat(torch.cat([hidden_states, output], dim=1))
|
||||
output = self.classifier(output)
|
||||
return output
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
Beit Model transformer with a semantic segmentation head on top e.g. for ADE20k, CityScapes.
|
||||
""",
|
||||
BEIT_START_DOCSTRING,
|
||||
)
|
||||
class BeitForSemanticSegmentation(BeitPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
self.num_labels = config.num_labels
|
||||
self.beit = BeitModel(config, add_pooling_layer=False)
|
||||
|
||||
# FPNs
|
||||
self.fpn1 = nn.Sequential(
|
||||
nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2),
|
||||
nn.BatchNorm2d(config.hidden_size),
|
||||
nn.GELU(),
|
||||
nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2),
|
||||
)
|
||||
self.fpn2 = nn.Sequential(
|
||||
nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2),
|
||||
)
|
||||
self.fpn3 = nn.Identity()
|
||||
self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
|
||||
# Semantic segmentation head(s)
|
||||
self.decode_head = BeitUperHead(config)
|
||||
self.auxiliary_head = BeitFCNHead(config) if config.use_auxiliary_head else None
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def compute_loss(self, logits, auxiliary_logits, labels):
|
||||
# upsample logits to the images' original size
|
||||
upsampled_logits = nn.functional.interpolate(
|
||||
logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
|
||||
)
|
||||
if auxiliary_logits is not None:
|
||||
upsampled_auxiliary_logits = nn.functional.interpolate(
|
||||
auxiliary_logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
|
||||
)
|
||||
# compute weighted loss
|
||||
loss_fct = CrossEntropyLoss(ignore_index=255)
|
||||
main_loss = loss_fct(upsampled_logits, labels)
|
||||
auxiliary_loss = loss_fct(upsampled_auxiliary_logits, labels)
|
||||
loss = main_loss + self.config.auxiliary_loss_weight * auxiliary_loss
|
||||
|
||||
return loss
|
||||
|
||||
@add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self,
|
||||
pixel_values=None,
|
||||
head_mask=None,
|
||||
labels=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, height, width)`, `optional`):
|
||||
Ground truth semantic segmentation maps for computing the loss. Indices should be in :obj:`[0, ...,
|
||||
config.num_labels - 1]`. If :obj:`config.num_labels > 1`, a classification loss is computed
|
||||
(Cross-Entropy).
|
||||
|
||||
Returns:
|
||||
|
||||
Examples::
|
||||
|
||||
>>> from transformers import BeitFeatureExtractor, BeitForSemanticSegmentation
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
|
||||
>>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> feature_extractor = BeitFeatureExtractor.from_pretrained('microsoft/beit-base-finetuned-ade-640-640')
|
||||
>>> model = BeitForSemanticSegmentation.from_pretrained('microsoft/beit-base-finetuned-ade-640-640')
|
||||
|
||||
>>> inputs = feature_extractor(images=image, return_tensors="pt")
|
||||
>>> outputs = model(**inputs)
|
||||
>>> # logits are of shape (batch_size, num_labels, height/4, width/4)
|
||||
>>> logits = outputs.logits
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
|
||||
outputs = self.beit(
|
||||
pixel_values,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=True, # we need the intermediate hidden states
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
encoder_hidden_states = outputs.hidden_states if return_dict else outputs[2]
|
||||
|
||||
# only keep certain features, and reshape
|
||||
# note that we do +1 as the encoder_hidden_states also includes the initial embeddings
|
||||
features = [feature for idx, feature in enumerate(encoder_hidden_states) if idx + 1 in self.config.out_indices]
|
||||
batch_size = pixel_values.shape[0]
|
||||
patch_resolution = self.config.image_size // self.config.patch_size
|
||||
features = [
|
||||
x[:, 1:, :].permute(0, 2, 1).reshape(batch_size, -1, patch_resolution, patch_resolution) for x in features
|
||||
]
|
||||
|
||||
# apply FPNs
|
||||
ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
|
||||
for i in range(len(features)):
|
||||
features[i] = ops[i](features[i])
|
||||
|
||||
logits = self.decode_head(features)
|
||||
auxiliary_logits = None
|
||||
if self.auxiliary_head is not None:
|
||||
auxiliary_logits = self.auxiliary_head(features)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
if self.config.num_labels == 1:
|
||||
raise ValueError("The number of labels should be greater than one")
|
||||
else:
|
||||
loss = self.compute_loss(logits, auxiliary_logits, labels)
|
||||
|
||||
if not return_dict:
|
||||
if output_hidden_states:
|
||||
output = (logits,) + outputs[2:]
|
||||
else:
|
||||
output = (logits,) + outputs[3:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return SequenceClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
hidden_states=outputs.hidden_states if output_hidden_states else None,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
@@ -606,6 +606,11 @@ class BeitForMaskedImageModeling:
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class BeitForSemanticSegmentation:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class BeitModel:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@@ -18,6 +18,8 @@
|
||||
import inspect
|
||||
import unittest
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
from transformers import BeitConfig
|
||||
from transformers.file_utils import cached_property, is_torch_available, is_vision_available
|
||||
from transformers.models.auto import get_values
|
||||
@@ -31,7 +33,13 @@ if is_torch_available():
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from transformers import MODEL_MAPPING, BeitForImageClassification, BeitForMaskedImageModeling, BeitModel
|
||||
from transformers import (
|
||||
MODEL_MAPPING,
|
||||
BeitForImageClassification,
|
||||
BeitForMaskedImageModeling,
|
||||
BeitForSemanticSegmentation,
|
||||
BeitModel,
|
||||
)
|
||||
from transformers.models.beit.modeling_beit import BEIT_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple
|
||||
|
||||
|
||||
@@ -53,7 +61,7 @@ class BeitModelTester:
|
||||
is_training=True,
|
||||
use_labels=True,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=5,
|
||||
num_hidden_layers=4,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
hidden_act="gelu",
|
||||
@@ -63,6 +71,7 @@ class BeitModelTester:
|
||||
initializer_range=0.02,
|
||||
num_labels=3,
|
||||
scope=None,
|
||||
out_indices=[0, 1, 2, 3],
|
||||
):
|
||||
self.parent = parent
|
||||
self.vocab_size = 100
|
||||
@@ -82,6 +91,7 @@ class BeitModelTester:
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.initializer_range = initializer_range
|
||||
self.scope = scope
|
||||
self.out_indices = out_indices
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||
@@ -109,6 +119,7 @@ class BeitModelTester:
|
||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||
is_decoder=False,
|
||||
initializer_range=self.initializer_range,
|
||||
out_indices=self.out_indices,
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, pixel_values, labels):
|
||||
@@ -160,7 +171,9 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
"""
|
||||
|
||||
all_model_classes = (
|
||||
(BeitModel, BeitForImageClassification, BeitForMaskedImageModeling) if is_torch_available() else ()
|
||||
(BeitModel, BeitForImageClassification, BeitForMaskedImageModeling, BeitForSemanticSegmentation)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
|
||||
test_pruning = False
|
||||
@@ -212,11 +225,14 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config.return_dict = True
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if model_class in get_values(MODEL_MAPPING):
|
||||
continue
|
||||
# we don't test BeitForMaskedImageModeling
|
||||
if model_class.__name__ == "BeitForMaskedImageModeling":
|
||||
if model_class in [*get_values(MODEL_MAPPING), BeitForMaskedImageModeling]:
|
||||
continue
|
||||
# TODO: remove the following 3 lines once we have a MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING
|
||||
# this can then be incorporated into _prepare_for_class in test_modeling_common.py
|
||||
elif model_class.__name__ == "BeitForSemanticSegmentation":
|
||||
batch_size, num_channels, height, width = inputs_dict["pixel_values"].shape
|
||||
inputs_dict["labels"] = torch.zeros([self.model_tester.batch_size, height, width]).long()
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.train()
|
||||
@@ -233,11 +249,17 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config.return_dict = True
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if model_class in get_values(MODEL_MAPPING) or not model_class.supports_gradient_checkpointing:
|
||||
continue
|
||||
# we don't test BeitForMaskedImageModeling
|
||||
if model_class.__name__ == "BeitForMaskedImageModeling":
|
||||
if (
|
||||
model_class in [*get_values(MODEL_MAPPING), BeitForMaskedImageModeling]
|
||||
or not model_class.supports_gradient_checkpointing
|
||||
):
|
||||
continue
|
||||
# TODO: remove the following 3 lines once we have a MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING
|
||||
# this can then be incorporated into _prepare_for_class in test_modeling_common.py
|
||||
elif model_class.__name__ == "BeitForSemanticSegmentation":
|
||||
batch_size, num_channels, height, width = inputs_dict["pixel_values"].shape
|
||||
inputs_dict["labels"] = torch.zeros([self.model_tester.batch_size, height, width]).long()
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.train()
|
||||
@@ -298,7 +320,8 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
||||
|
||||
attentions = outputs.attentions
|
||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
||||
|
||||
self.assertListEqual(
|
||||
@@ -316,15 +339,9 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
if hasattr(self.model_tester, "num_hidden_states_types"):
|
||||
added_hidden_states = self.model_tester.num_hidden_states_types
|
||||
elif self.is_encoder_decoder:
|
||||
added_hidden_states = 2
|
||||
else:
|
||||
added_hidden_states = 1
|
||||
self.assertEqual(out_len + added_hidden_states, len(outputs))
|
||||
self.assertEqual(out_len + 1, len(outputs))
|
||||
|
||||
self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
||||
self_attentions = outputs.attentions
|
||||
|
||||
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
|
||||
self.assertListEqual(
|
||||
@@ -472,3 +489,32 @@ class BeitModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
expected_class_idx = 2396
|
||||
self.assertEqual(logits.argmax(-1).item(), expected_class_idx)
|
||||
|
||||
@slow
|
||||
def test_inference_semantic_segmentation(self):
|
||||
model = BeitForSemanticSegmentation.from_pretrained("microsoft/beit-base-finetuned-ade-640-640")
|
||||
model = model.to(torch_device)
|
||||
|
||||
feature_extractor = BeitFeatureExtractor(do_resize=True, size=640, do_center_crop=False)
|
||||
|
||||
ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
|
||||
image = Image.open(ds[0]["file"])
|
||||
inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
|
||||
|
||||
# forward pass
|
||||
outputs = model(**inputs)
|
||||
logits = outputs.logits
|
||||
|
||||
# verify the logits
|
||||
expected_shape = torch.Size((1, 150, 160, 160))
|
||||
self.assertEqual(logits.shape, expected_shape)
|
||||
|
||||
expected_slice = torch.tensor(
|
||||
[
|
||||
[[-4.9225, -2.3954, -3.0522], [-2.8822, -1.0046, -1.7561], [-2.9549, -1.3228, -2.1347]],
|
||||
[[-5.8168, -3.4129, -4.0778], [-3.8651, -2.2214, -3.0277], [-3.8356, -2.4643, -3.3535]],
|
||||
[[-0.0078, 3.9952, 4.0754], [2.9856, 4.6944, 5.0035], [3.2413, 4.7813, 4.9969]],
|
||||
]
|
||||
).to(torch_device)
|
||||
|
||||
self.assertTrue(torch.allclose(logits[0, :3, :3, :3], expected_slice, atol=1e-4))
|
||||
|
||||
@@ -88,7 +88,7 @@ if is_torch_fx_available():
|
||||
def _config_zero_init(config):
|
||||
configs_no_init = copy.deepcopy(config)
|
||||
for key in configs_no_init.__dict__.keys():
|
||||
if "_range" in key or "_std" in key or "initializer_factor" in key:
|
||||
if "_range" in key or "_std" in key or "initializer_factor" in key or "layer_scale" in key:
|
||||
setattr(configs_no_init, key, 1e-10)
|
||||
return configs_no_init
|
||||
|
||||
|
||||
@@ -102,6 +102,7 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
|
||||
# models to ignore for model xxx mapping
|
||||
"SegformerDecodeHead",
|
||||
"SegformerForSemanticSegmentation",
|
||||
"BeitForSemanticSegmentation",
|
||||
"FlaxBeitForMaskedImageModeling",
|
||||
"BeitForMaskedImageModeling",
|
||||
"CLIPTextModel",
|
||||
|
||||
Reference in New Issue
Block a user