Add BeitForSemanticSegmentation (#14096)
* Add first draft * Make forward pass work * Improve conversion script * Add notebook that checks if it works * Add BeitForSemanticSegmentation to the tests * More improvements * Make BeitForSemanticSegmentation consistent with Segformer * Small bug fix * Add BeitForSemanticSegmentation to docs * Make sure model doesn't output hidden states when the user doesn't want to * Make it possible to convert the large model * Fix issue * Fix conversion script for large model * Add auxiliary_head option to semantic segmentation model * Apply suggestions from @sgugger's review * Apply suggestions from code review * Fix failing test Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
@@ -98,6 +98,13 @@ BeitForImageClassification
|
|||||||
:members: forward
|
:members: forward
|
||||||
|
|
||||||
|
|
||||||
|
BeitForSemanticSegmentation
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.BeitForSemanticSegmentation
|
||||||
|
:members: forward
|
||||||
|
|
||||||
|
|
||||||
FlaxBeitModel
|
FlaxBeitModel
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
|||||||
@@ -638,6 +638,7 @@ if is_torch_available():
|
|||||||
"BEIT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
"BEIT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
"BeitForImageClassification",
|
"BeitForImageClassification",
|
||||||
"BeitForMaskedImageModeling",
|
"BeitForMaskedImageModeling",
|
||||||
|
"BeitForSemanticSegmentation",
|
||||||
"BeitModel",
|
"BeitModel",
|
||||||
"BeitPreTrainedModel",
|
"BeitPreTrainedModel",
|
||||||
]
|
]
|
||||||
@@ -2483,6 +2484,7 @@ if TYPE_CHECKING:
|
|||||||
BEIT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
BEIT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
BeitForImageClassification,
|
BeitForImageClassification,
|
||||||
BeitForMaskedImageModeling,
|
BeitForMaskedImageModeling,
|
||||||
|
BeitForSemanticSegmentation,
|
||||||
BeitModel,
|
BeitModel,
|
||||||
BeitPreTrainedModel,
|
BeitPreTrainedModel,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ if is_torch_available():
|
|||||||
"BEIT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
"BEIT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
"BeitForImageClassification",
|
"BeitForImageClassification",
|
||||||
"BeitForMaskedImageModeling",
|
"BeitForMaskedImageModeling",
|
||||||
|
"BeitForSemanticSegmentation",
|
||||||
"BeitModel",
|
"BeitModel",
|
||||||
"BeitPreTrainedModel",
|
"BeitPreTrainedModel",
|
||||||
]
|
]
|
||||||
@@ -57,6 +58,7 @@ if TYPE_CHECKING:
|
|||||||
BEIT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
BEIT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
BeitForImageClassification,
|
BeitForImageClassification,
|
||||||
BeitForMaskedImageModeling,
|
BeitForMaskedImageModeling,
|
||||||
|
BeitForSemanticSegmentation,
|
||||||
BeitModel,
|
BeitModel,
|
||||||
BeitPreTrainedModel,
|
BeitPreTrainedModel,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -78,6 +78,20 @@ class BeitConfig(PretrainedConfig):
|
|||||||
use_mean_pooling (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
use_mean_pooling (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
Whether to mean pool the final hidden states of the patches instead of using the final hidden state of the
|
Whether to mean pool the final hidden states of the patches instead of using the final hidden state of the
|
||||||
CLS token, before applying the classification head.
|
CLS token, before applying the classification head.
|
||||||
|
out_indices (:obj:`List[int]`, `optional`, defaults to :obj:`[3, 5, 7, 11]`):
|
||||||
|
Indices of the feature maps to use for semantic segmentation.
|
||||||
|
pool_scales (:obj:`Tuple[int]`, `optional`, defaults to :obj:`[1, 2, 3, 6]`):
|
||||||
|
Pooling scales used in the Pyramid Pooling Module applied on the last feature map.
|
||||||
|
use_auxiliary_head (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
|
Whether to use an auxiliary head during training.
|
||||||
|
auxiliary_loss_weight (:obj:`float`, `optional`, defaults to 0.4):
|
||||||
|
Weight of the cross-entropy loss of the auxiliary head.
|
||||||
|
auxiliary_channels (:obj:`int`, `optional`, defaults to 256):
|
||||||
|
Number of channels to use in the auxiliary head.
|
||||||
|
auxiliary_num_convs (:obj:`int`, `optional`, defaults to 1):
|
||||||
|
Number of convolutional layers to use in the auxiliary head.
|
||||||
|
auxiliary_concat_input (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Whether to concatenate the output of the auxiliary head with the input before the classification layer.
|
||||||
|
|
||||||
Example::
|
Example::
|
||||||
|
|
||||||
@@ -117,6 +131,13 @@ class BeitConfig(PretrainedConfig):
|
|||||||
layer_scale_init_value=0.1,
|
layer_scale_init_value=0.1,
|
||||||
drop_path_rate=0.1,
|
drop_path_rate=0.1,
|
||||||
use_mean_pooling=True,
|
use_mean_pooling=True,
|
||||||
|
out_indices=[3, 5, 7, 11],
|
||||||
|
pool_scales=[1, 2, 3, 6],
|
||||||
|
use_auxiliary_head=True,
|
||||||
|
auxiliary_loss_weight=0.4,
|
||||||
|
auxiliary_channels=256,
|
||||||
|
auxiliary_num_convs=1,
|
||||||
|
auxiliary_concat_input=False,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
@@ -142,3 +163,12 @@ class BeitConfig(PretrainedConfig):
|
|||||||
self.layer_scale_init_value = layer_scale_init_value
|
self.layer_scale_init_value = layer_scale_init_value
|
||||||
self.drop_path_rate = drop_path_rate
|
self.drop_path_rate = drop_path_rate
|
||||||
self.use_mean_pooling = use_mean_pooling
|
self.use_mean_pooling = use_mean_pooling
|
||||||
|
# decode head attributes (semantic segmentation)
|
||||||
|
self.out_indices = out_indices
|
||||||
|
self.pool_scales = pool_scales
|
||||||
|
# auxiliary head attributes (semantic segmentation)
|
||||||
|
self.use_auxiliary_head = use_auxiliary_head
|
||||||
|
self.auxiliary_loss_weight = auxiliary_loss_weight
|
||||||
|
self.auxiliary_channels = auxiliary_channels
|
||||||
|
self.auxiliary_num_convs = auxiliary_num_convs
|
||||||
|
self.auxiliary_concat_input = auxiliary_concat_input
|
||||||
|
|||||||
@@ -20,11 +20,18 @@ import json
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from datasets import load_dataset
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from huggingface_hub import cached_download, hf_hub_url
|
from huggingface_hub import cached_download, hf_hub_url
|
||||||
from transformers import BeitConfig, BeitFeatureExtractor, BeitForImageClassification, BeitForMaskedImageModeling
|
from transformers import (
|
||||||
|
BeitConfig,
|
||||||
|
BeitFeatureExtractor,
|
||||||
|
BeitForImageClassification,
|
||||||
|
BeitForMaskedImageModeling,
|
||||||
|
BeitForSemanticSegmentation,
|
||||||
|
)
|
||||||
from transformers.utils import logging
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
|
||||||
@@ -33,27 +40,33 @@ logger = logging.get_logger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
# here we list all keys to be renamed (original name on the left, our name on the right)
|
# here we list all keys to be renamed (original name on the left, our name on the right)
|
||||||
def create_rename_keys(config, has_lm_head=False):
|
def create_rename_keys(config, has_lm_head=False, is_semantic=False):
|
||||||
|
prefix = "backbone." if is_semantic else ""
|
||||||
|
|
||||||
rename_keys = []
|
rename_keys = []
|
||||||
for i in range(config.num_hidden_layers):
|
for i in range(config.num_hidden_layers):
|
||||||
# encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
|
# encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
|
||||||
rename_keys.append((f"blocks.{i}.norm1.weight", f"beit.encoder.layer.{i}.layernorm_before.weight"))
|
rename_keys.append((f"{prefix}blocks.{i}.norm1.weight", f"beit.encoder.layer.{i}.layernorm_before.weight"))
|
||||||
rename_keys.append((f"blocks.{i}.norm1.bias", f"beit.encoder.layer.{i}.layernorm_before.bias"))
|
rename_keys.append((f"{prefix}blocks.{i}.norm1.bias", f"beit.encoder.layer.{i}.layernorm_before.bias"))
|
||||||
rename_keys.append((f"blocks.{i}.attn.proj.weight", f"beit.encoder.layer.{i}.attention.output.dense.weight"))
|
rename_keys.append(
|
||||||
rename_keys.append((f"blocks.{i}.attn.proj.bias", f"beit.encoder.layer.{i}.attention.output.dense.bias"))
|
(f"{prefix}blocks.{i}.attn.proj.weight", f"beit.encoder.layer.{i}.attention.output.dense.weight")
|
||||||
rename_keys.append((f"blocks.{i}.norm2.weight", f"beit.encoder.layer.{i}.layernorm_after.weight"))
|
)
|
||||||
rename_keys.append((f"blocks.{i}.norm2.bias", f"beit.encoder.layer.{i}.layernorm_after.bias"))
|
rename_keys.append(
|
||||||
rename_keys.append((f"blocks.{i}.mlp.fc1.weight", f"beit.encoder.layer.{i}.intermediate.dense.weight"))
|
(f"{prefix}blocks.{i}.attn.proj.bias", f"beit.encoder.layer.{i}.attention.output.dense.bias")
|
||||||
rename_keys.append((f"blocks.{i}.mlp.fc1.bias", f"beit.encoder.layer.{i}.intermediate.dense.bias"))
|
)
|
||||||
rename_keys.append((f"blocks.{i}.mlp.fc2.weight", f"beit.encoder.layer.{i}.output.dense.weight"))
|
rename_keys.append((f"{prefix}blocks.{i}.norm2.weight", f"beit.encoder.layer.{i}.layernorm_after.weight"))
|
||||||
rename_keys.append((f"blocks.{i}.mlp.fc2.bias", f"beit.encoder.layer.{i}.output.dense.bias"))
|
rename_keys.append((f"{prefix}blocks.{i}.norm2.bias", f"beit.encoder.layer.{i}.layernorm_after.bias"))
|
||||||
|
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc1.weight", f"beit.encoder.layer.{i}.intermediate.dense.weight"))
|
||||||
|
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc1.bias", f"beit.encoder.layer.{i}.intermediate.dense.bias"))
|
||||||
|
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc2.weight", f"beit.encoder.layer.{i}.output.dense.weight"))
|
||||||
|
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc2.bias", f"beit.encoder.layer.{i}.output.dense.bias"))
|
||||||
|
|
||||||
# projection layer + position embeddings
|
# projection layer + position embeddings
|
||||||
rename_keys.extend(
|
rename_keys.extend(
|
||||||
[
|
[
|
||||||
("cls_token", "beit.embeddings.cls_token"),
|
(f"{prefix}cls_token", "beit.embeddings.cls_token"),
|
||||||
("patch_embed.proj.weight", "beit.embeddings.patch_embeddings.projection.weight"),
|
(f"{prefix}patch_embed.proj.weight", "beit.embeddings.patch_embeddings.projection.weight"),
|
||||||
("patch_embed.proj.bias", "beit.embeddings.patch_embeddings.projection.bias"),
|
(f"{prefix}patch_embed.proj.bias", "beit.embeddings.patch_embeddings.projection.bias"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -74,6 +87,16 @@ def create_rename_keys(config, has_lm_head=False):
|
|||||||
("norm.bias", "layernorm.bias"),
|
("norm.bias", "layernorm.bias"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
elif is_semantic:
|
||||||
|
# semantic segmentation classification heads
|
||||||
|
rename_keys.extend(
|
||||||
|
[
|
||||||
|
("decode_head.conv_seg.weight", "decode_head.classifier.weight"),
|
||||||
|
("decode_head.conv_seg.bias", "decode_head.classifier.bias"),
|
||||||
|
("auxiliary_head.conv_seg.weight", "auxiliary_head.classifier.weight"),
|
||||||
|
("auxiliary_head.conv_seg.bias", "auxiliary_head.classifier.bias"),
|
||||||
|
]
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# layernorm + classification head
|
# layernorm + classification head
|
||||||
rename_keys.extend(
|
rename_keys.extend(
|
||||||
@@ -89,45 +112,45 @@ def create_rename_keys(config, has_lm_head=False):
|
|||||||
|
|
||||||
|
|
||||||
# we split up the matrix of each encoder layer into queries, keys and values
|
# we split up the matrix of each encoder layer into queries, keys and values
|
||||||
def read_in_q_k_v(state_dict, config, has_lm_head=False):
|
def read_in_q_k_v(state_dict, config, has_lm_head=False, is_semantic=False):
|
||||||
for i in range(config.num_hidden_layers):
|
for i in range(config.num_hidden_layers):
|
||||||
prefix = "beit."
|
prefix = "backbone." if is_semantic else ""
|
||||||
# queries, keys and values
|
# queries, keys and values
|
||||||
in_proj_weight = state_dict.pop(f"blocks.{i}.attn.qkv.weight")
|
in_proj_weight = state_dict.pop(f"{prefix}blocks.{i}.attn.qkv.weight")
|
||||||
q_bias = state_dict.pop(f"blocks.{i}.attn.q_bias")
|
q_bias = state_dict.pop(f"{prefix}blocks.{i}.attn.q_bias")
|
||||||
v_bias = state_dict.pop(f"blocks.{i}.attn.v_bias")
|
v_bias = state_dict.pop(f"{prefix}blocks.{i}.attn.v_bias")
|
||||||
|
|
||||||
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[
|
state_dict[f"beit.encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[
|
||||||
: config.hidden_size, :
|
: config.hidden_size, :
|
||||||
]
|
]
|
||||||
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.bias"] = q_bias
|
state_dict[f"beit.encoder.layer.{i}.attention.attention.query.bias"] = q_bias
|
||||||
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
|
state_dict[f"beit.encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
|
||||||
config.hidden_size : config.hidden_size * 2, :
|
config.hidden_size : config.hidden_size * 2, :
|
||||||
]
|
]
|
||||||
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[
|
state_dict[f"beit.encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[
|
||||||
-config.hidden_size :, :
|
-config.hidden_size :, :
|
||||||
]
|
]
|
||||||
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.bias"] = v_bias
|
state_dict[f"beit.encoder.layer.{i}.attention.attention.value.bias"] = v_bias
|
||||||
|
|
||||||
# gamma_1 and gamma_2
|
# gamma_1 and gamma_2
|
||||||
# we call them lambda because otherwise they are renamed when using .from_pretrained
|
# we call them lambda because otherwise they are renamed when using .from_pretrained
|
||||||
gamma_1 = state_dict.pop(f"blocks.{i}.gamma_1")
|
gamma_1 = state_dict.pop(f"{prefix}blocks.{i}.gamma_1")
|
||||||
gamma_2 = state_dict.pop(f"blocks.{i}.gamma_2")
|
gamma_2 = state_dict.pop(f"{prefix}blocks.{i}.gamma_2")
|
||||||
|
|
||||||
state_dict[f"{prefix}encoder.layer.{i}.lambda_1"] = gamma_1
|
state_dict[f"beit.encoder.layer.{i}.lambda_1"] = gamma_1
|
||||||
state_dict[f"{prefix}encoder.layer.{i}.lambda_2"] = gamma_2
|
state_dict[f"beit.encoder.layer.{i}.lambda_2"] = gamma_2
|
||||||
|
|
||||||
# relative_position bias table + index
|
# relative_position bias table + index
|
||||||
if not has_lm_head:
|
if not has_lm_head:
|
||||||
# each layer has its own relative position bias
|
# each layer has its own relative position bias
|
||||||
table = state_dict.pop(f"blocks.{i}.attn.relative_position_bias_table")
|
table = state_dict.pop(f"{prefix}blocks.{i}.attn.relative_position_bias_table")
|
||||||
index = state_dict.pop(f"blocks.{i}.attn.relative_position_index")
|
index = state_dict.pop(f"{prefix}blocks.{i}.attn.relative_position_index")
|
||||||
|
|
||||||
state_dict[
|
state_dict[
|
||||||
f"{prefix}encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_bias_table"
|
f"beit.encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_bias_table"
|
||||||
] = table
|
] = table
|
||||||
state_dict[
|
state_dict[
|
||||||
f"{prefix}encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_index"
|
f"beit.encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_index"
|
||||||
] = index
|
] = index
|
||||||
|
|
||||||
|
|
||||||
@@ -152,6 +175,7 @@ def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path):
|
|||||||
# define default BEiT configuration
|
# define default BEiT configuration
|
||||||
config = BeitConfig()
|
config = BeitConfig()
|
||||||
has_lm_head = False
|
has_lm_head = False
|
||||||
|
is_semantic = False
|
||||||
repo_id = "datasets/huggingface/label-files"
|
repo_id = "datasets/huggingface/label-files"
|
||||||
# set config parameters based on URL
|
# set config parameters based on URL
|
||||||
if checkpoint_url[-9:-4] == "pt22k":
|
if checkpoint_url[-9:-4] == "pt22k":
|
||||||
@@ -185,8 +209,19 @@ def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path):
|
|||||||
config.image_size = 384
|
config.image_size = 384
|
||||||
if "512" in checkpoint_url:
|
if "512" in checkpoint_url:
|
||||||
config.image_size = 512
|
config.image_size = 512
|
||||||
|
elif "ade20k" in checkpoint_url:
|
||||||
|
# fine-tuning
|
||||||
|
config.use_relative_position_bias = True
|
||||||
|
config.num_labels = 150
|
||||||
|
filename = "ade20k-id2label.json"
|
||||||
|
id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename)), "r"))
|
||||||
|
id2label = {int(k): v for k, v in id2label.items()}
|
||||||
|
config.id2label = id2label
|
||||||
|
config.label2id = {v: k for k, v in id2label.items()}
|
||||||
|
config.image_size = 640
|
||||||
|
is_semantic = True
|
||||||
else:
|
else:
|
||||||
raise ValueError("Checkpoint not supported, URL should either end with 'pt22k', 'ft22k' or 'to1k'")
|
raise ValueError("Checkpoint not supported, URL should either end with 'pt22k', 'ft22k', 'to1k' or 'ade20k'")
|
||||||
|
|
||||||
# size of the architecture
|
# size of the architecture
|
||||||
if "base" in checkpoint_url:
|
if "base" in checkpoint_url:
|
||||||
@@ -196,27 +231,48 @@ def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path):
|
|||||||
config.intermediate_size = 4096
|
config.intermediate_size = 4096
|
||||||
config.num_hidden_layers = 24
|
config.num_hidden_layers = 24
|
||||||
config.num_attention_heads = 16
|
config.num_attention_heads = 16
|
||||||
|
if "ade20k" in checkpoint_url:
|
||||||
|
config.image_size = 640
|
||||||
|
config.out_indices = [7, 11, 15, 23]
|
||||||
else:
|
else:
|
||||||
raise ValueError("Should either find 'base' or 'large' in checkpoint URL")
|
raise ValueError("Should either find 'base' or 'large' in checkpoint URL")
|
||||||
|
|
||||||
# load state_dict of original model, remove and rename some keys
|
# load state_dict of original model, remove and rename some keys
|
||||||
state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu", check_hash=True)["model"]
|
state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu", check_hash=True)
|
||||||
rename_keys = create_rename_keys(config, has_lm_head=has_lm_head)
|
state_dict = state_dict["model"] if "ade20k" not in checkpoint_url else state_dict["state_dict"]
|
||||||
|
|
||||||
|
rename_keys = create_rename_keys(config, has_lm_head=has_lm_head, is_semantic=is_semantic)
|
||||||
for src, dest in rename_keys:
|
for src, dest in rename_keys:
|
||||||
rename_key(state_dict, src, dest)
|
rename_key(state_dict, src, dest)
|
||||||
read_in_q_k_v(state_dict, config, has_lm_head=has_lm_head)
|
read_in_q_k_v(state_dict, config, has_lm_head=has_lm_head, is_semantic=is_semantic)
|
||||||
|
if is_semantic:
|
||||||
|
# add prefix to decoder keys
|
||||||
|
for key, val in state_dict.copy().items():
|
||||||
|
val = state_dict.pop(key)
|
||||||
|
if key.startswith("backbone.fpn"):
|
||||||
|
key = key.replace("backbone.fpn", "fpn")
|
||||||
|
state_dict[key] = val
|
||||||
|
|
||||||
# load HuggingFace model
|
# load HuggingFace model
|
||||||
if checkpoint_url[-9:-4] == "pt22k":
|
if checkpoint_url[-9:-4] == "pt22k":
|
||||||
model = BeitForMaskedImageModeling(config)
|
model = BeitForMaskedImageModeling(config)
|
||||||
|
elif "ade20k" in checkpoint_url:
|
||||||
|
model = BeitForSemanticSegmentation(config)
|
||||||
else:
|
else:
|
||||||
model = BeitForImageClassification(config)
|
model = BeitForImageClassification(config)
|
||||||
model.eval()
|
model.eval()
|
||||||
model.load_state_dict(state_dict)
|
model.load_state_dict(state_dict)
|
||||||
|
|
||||||
# Check outputs on an image
|
# Check outputs on an image
|
||||||
feature_extractor = BeitFeatureExtractor(size=config.image_size, resample=Image.BILINEAR, do_center_crop=False)
|
if is_semantic:
|
||||||
encoding = feature_extractor(images=prepare_img(), return_tensors="pt")
|
feature_extractor = BeitFeatureExtractor(size=config.image_size, do_center_crop=False)
|
||||||
|
ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
|
||||||
|
image = Image.open(ds[0]["file"])
|
||||||
|
else:
|
||||||
|
feature_extractor = BeitFeatureExtractor(size=config.image_size, resample=Image.BILINEAR, do_center_crop=False)
|
||||||
|
image = prepare_img()
|
||||||
|
|
||||||
|
encoding = feature_extractor(images=image, return_tensors="pt")
|
||||||
pixel_values = encoding["pixel_values"]
|
pixel_values = encoding["pixel_values"]
|
||||||
|
|
||||||
outputs = model(pixel_values)
|
outputs = model(pixel_values)
|
||||||
@@ -257,15 +313,39 @@ def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path):
|
|||||||
elif checkpoint_url[:-4].endswith("beit_large_patch16_512_pt22k_ft22kto1k"):
|
elif checkpoint_url[:-4].endswith("beit_large_patch16_512_pt22k_ft22kto1k"):
|
||||||
expected_logits = torch.tensor([-0.3062, 0.7261, 0.4852])
|
expected_logits = torch.tensor([-0.3062, 0.7261, 0.4852])
|
||||||
expected_class_idx = 761
|
expected_class_idx = 761
|
||||||
|
elif checkpoint_url[:-4].endswith("beit_base_patch16_640_pt22k_ft22ktoade20k"):
|
||||||
|
expected_shape = (1, 150, 160, 160)
|
||||||
|
expected_logits = torch.tensor(
|
||||||
|
[
|
||||||
|
[[-4.9225, -2.3954, -3.0522], [-2.8822, -1.0046, -1.7561], [-2.9549, -1.3228, -2.1347]],
|
||||||
|
[[-5.8168, -3.4129, -4.0778], [-3.8651, -2.2214, -3.0277], [-3.8356, -2.4643, -3.3535]],
|
||||||
|
[[-0.0078, 3.9952, 4.0754], [2.9856, 4.6944, 5.0035], [3.2413, 4.7813, 4.9969]],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
elif checkpoint_url[:-4].endswith("beit_large_patch16_640_pt22k_ft22ktoade20k"):
|
||||||
|
expected_shape = (1, 150, 160, 160)
|
||||||
|
expected_logits = torch.tensor(
|
||||||
|
[
|
||||||
|
[[-4.3305, -2.3049, -3.0161], [-2.9591, -1.5305, -2.2251], [-3.4198, -1.8004, -2.9062]],
|
||||||
|
[[-5.8922, -3.7435, -4.3978], [-4.2063, -2.7872, -3.4755], [-4.2791, -3.1874, -4.1681]],
|
||||||
|
[[0.9895, 4.3467, 4.7663], [4.2476, 5.6830, 6.1518], [4.5550, 6.2495, 6.5154]],
|
||||||
|
]
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Can't verify logits as model is not supported")
|
raise ValueError("Can't verify logits as model is not supported")
|
||||||
|
|
||||||
assert logits.shape == expected_shape, "Shape of logits not as expected"
|
assert logits.shape == expected_shape, "Shape of logits not as expected"
|
||||||
print("Shape of logits:", logits.shape)
|
|
||||||
if not has_lm_head:
|
if not has_lm_head:
|
||||||
print("Predicted class idx:", logits.argmax(-1).item())
|
if is_semantic:
|
||||||
assert torch.allclose(logits[0, :3], expected_logits, atol=1e-3), "First elements of logits not as expected"
|
assert torch.allclose(
|
||||||
assert logits.argmax(-1).item() == expected_class_idx, "Predicted class index not as expected"
|
logits[0, :3, :3, :3], expected_logits, atol=1e-3
|
||||||
|
), "First elements of logits not as expected"
|
||||||
|
else:
|
||||||
|
print("Predicted class idx:", logits.argmax(-1).item())
|
||||||
|
assert torch.allclose(
|
||||||
|
logits[0, :3], expected_logits, atol=1e-3
|
||||||
|
), "First elements of logits not as expected"
|
||||||
|
assert logits.argmax(-1).item() == expected_class_idx, "Predicted class index not as expected"
|
||||||
|
|
||||||
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||||
print(f"Saving model to {pytorch_dump_folder_path}")
|
print(f"Saving model to {pytorch_dump_folder_path}")
|
||||||
|
|||||||
@@ -163,6 +163,7 @@ class PatchEmbeddings(nn.Module):
|
|||||||
f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
|
f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
|
||||||
)
|
)
|
||||||
x = self.projection(pixel_values).flatten(2).transpose(1, 2)
|
x = self.projection(pixel_values).flatten(2).transpose(1, 2)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@@ -499,7 +500,7 @@ class BeitPreTrainedModel(PreTrainedModel):
|
|||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
"""Initialize the weights"""
|
"""Initialize the weights"""
|
||||||
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
|
||||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||||
@@ -851,3 +852,354 @@ class BeitForImageClassification(BeitPreTrainedModel):
|
|||||||
hidden_states=outputs.hidden_states,
|
hidden_states=outputs.hidden_states,
|
||||||
attentions=outputs.attentions,
|
attentions=outputs.attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BeitConvModule(nn.Module):
    """
    Convolution -> batch normalization -> ReLU, packaged as a single module so that this recurring combination can
    be instantiated in one call.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the convolution (and normalized by the batch norm).
        kernel_size: Size of the convolution kernel.
        padding: Padding forwarded to the convolution. Default: 0.
        bias: Whether the convolution has a learnable bias. Default: False.
        dilation: Dilation forwarded to the convolution. Default: 1.

    Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding=0, bias=False, dilation=1):
        super().__init__()
        conv_options = dict(kernel_size=kernel_size, padding=padding, bias=bias, dilation=dilation)
        self.conv = nn.Conv2d(in_channels, out_channels, **conv_options)
        self.bn = nn.BatchNorm2d(out_channels)
        self.activation = nn.ReLU()

    def forward(self, input):
        # conv -> norm -> activation, applied in that order
        return self.activation(self.bn(self.conv(input)))
|
||||||
|
|
||||||
|
|
||||||
|
class BeitPyramidPoolingModule(nn.ModuleList):
    """
    Pyramid Pooling Module (PPM) used in PSPNet.

    The module list holds one branch per pooling scale. Each branch adaptively average-pools its input to the given
    scale and then applies a 1x1 :class:`BeitConvModule`.

    Args:
        pool_scales (tuple[int]): Pooling scales used in the Pyramid Pooling Module.
        in_channels (int): Input channels.
        channels (int): Channels after modules, before conv_seg.
        align_corners (bool): align_corners argument of F.interpolate.

    Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
    """

    def __init__(self, pool_scales, in_channels, channels, align_corners):
        super().__init__()
        self.pool_scales = pool_scales
        self.align_corners = align_corners
        self.in_channels = in_channels
        self.channels = channels
        for scale in pool_scales:
            pooling = nn.AdaptiveAvgPool2d(scale)
            projection = BeitConvModule(in_channels, channels, kernel_size=1)
            self.append(nn.Sequential(pooling, projection))

    def forward(self, x):
        # every branch output is resized back to the spatial size of the input
        target_size = x.size()[2:]
        return [
            nn.functional.interpolate(
                branch(x), size=target_size, mode="bilinear", align_corners=self.align_corners
            )
            for branch in self
        ]
|
||||||
|
|
||||||
|
|
||||||
|
class BeitUperHead(nn.Module):
    """
    Unified Perceptual Parsing for Scene Understanding. This head is the implementation of `UPerNet
    <https://arxiv.org/abs/1807.10221>`_.

    The head consumes four feature maps with :obj:`config.hidden_size` channels each. The last one goes through a
    Pyramid Pooling Module, the other three through lateral convolutions; the results are fused top-down, resized
    to the resolution of the first feature map, concatenated and classified.

    Args:
        config (BeitConfig): Configuration. The attributes ``pool_scales``, ``hidden_size`` and ``num_labels`` are
            read.

    Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
    """

    def __init__(self, config):
        super().__init__()

        self.pool_scales = config.pool_scales  # e.g. (1, 2, 3, 6)
        self.in_channels = [config.hidden_size] * 4  # e.g. [768, 768, 768, 768]
        self.channels = config.hidden_size
        self.align_corners = False
        self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1)

        # PSP Module
        self.psp_modules = BeitPyramidPoolingModule(
            self.pool_scales,
            self.in_channels[-1],
            self.channels,
            align_corners=self.align_corners,
        )
        # fuses the last feature map with one pooled branch per pooling scale
        self.bottleneck = BeitConvModule(
            self.in_channels[-1] + len(self.pool_scales) * self.channels,
            self.channels,
            kernel_size=3,
            padding=1,
        )
        # FPN Module
        self.lateral_convs = nn.ModuleList()
        self.fpn_convs = nn.ModuleList()
        for in_channels in self.in_channels[:-1]:  # skip the top layer
            l_conv = BeitConvModule(in_channels, self.channels, kernel_size=1)
            fpn_conv = BeitConvModule(self.channels, self.channels, kernel_size=3, padding=1)
            self.lateral_convs.append(l_conv)
            self.fpn_convs.append(fpn_conv)

        self.fpn_bottleneck = BeitConvModule(
            len(self.in_channels) * self.channels,
            self.channels,
            kernel_size=3,
            padding=1,
        )

    def psp_forward(self, inputs):
        """
        Apply the Pyramid Pooling Module to the last element of ``inputs``.

        The last feature map is concatenated along the channel dimension with every pooled-and-resized branch and
        then reduced to ``self.channels`` channels by the bottleneck convolution.
        """
        x = inputs[-1]
        psp_outs = [x]
        psp_outs.extend(self.psp_modules(x))
        psp_outs = torch.cat(psp_outs, dim=1)
        output = self.bottleneck(psp_outs)

        return output

    def forward(self, encoder_hidden_states):
        """
        Compute segmentation logits from the selected backbone feature maps.

        Args:
            encoder_hidden_states: Indexable sequence of feature maps, each with ``config.hidden_size`` channels.
                The first ``len(self.lateral_convs)`` entries feed the lateral convolutions and the last entry feeds
                the PSP module (presumably four ``(batch, hidden_size, height, width)`` tensors - confirm against
                the caller).

        Returns:
            Logits with ``config.num_labels`` channels at the spatial resolution of the first feature map.
        """
        # build laterals
        laterals = [lateral_conv(encoder_hidden_states[i]) for i, lateral_conv in enumerate(self.lateral_convs)]

        laterals.append(self.psp_forward(encoder_hidden_states))

        # build top-down path
        used_backbone_levels = len(laterals)
        for i in range(used_backbone_levels - 1, 0, -1):
            prev_shape = laterals[i - 1].shape[2:]
            # Out-of-place add on purpose: `+=` would overwrite the ReLU output of the lateral conv block in place,
            # and autograd may still need that tensor to compute the backward pass.
            laterals[i - 1] = laterals[i - 1] + nn.functional.interpolate(
                laterals[i], size=prev_shape, mode="bilinear", align_corners=self.align_corners
            )

        # build outputs
        fpn_outs = [self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels - 1)]
        # append psp feature
        fpn_outs.append(laterals[-1])

        # bring every level to the resolution of the first one so they can be concatenated
        for i in range(used_backbone_levels - 1, 0, -1):
            fpn_outs[i] = nn.functional.interpolate(
                fpn_outs[i], size=fpn_outs[0].shape[2:], mode="bilinear", align_corners=self.align_corners
            )
        fpn_outs = torch.cat(fpn_outs, dim=1)
        output = self.fpn_bottleneck(fpn_outs)
        output = self.classifier(output)

        return output
|
||||||
|
|
||||||
|
|
||||||
|
class BeitFCNHead(nn.Module):
    """
    Fully Convolution Networks for Semantic Segmentation. This head is an implementation of `FCNNet
    <https://arxiv.org/abs/1411.4038>`_.

    Args:
        config (BeitConfig): Configuration. The attributes ``hidden_size``, ``auxiliary_channels``,
            ``auxiliary_num_convs``, ``auxiliary_concat_input`` and ``num_labels`` are read.
        in_index (int): Index of the feature map, within the sequence passed to ``forward``, that the head operates
            on. Default: 2.
        kernel_size (int): The kernel size for convs in the head. Default: 3.
        dilation (int): The dilation rate for convs in the head. Default: 1.

    Raises:
        ValueError: If ``config.auxiliary_num_convs`` is 0 while ``config.hidden_size`` differs from
            ``config.auxiliary_channels``.

    Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
    """

    def __init__(self, config, in_index=2, kernel_size=3, dilation=1):
        super().__init__()
        self.in_channels = config.hidden_size
        self.channels = config.auxiliary_channels
        self.num_convs = config.auxiliary_num_convs
        self.concat_input = config.auxiliary_concat_input
        self.in_index = in_index

        if self.num_convs == 0:
            # Without convolutions the features reach the following layers unchanged, and those layers are built
            # for `self.channels` input channels, so the two sizes have to agree.
            if self.in_channels != self.channels:
                raise ValueError(
                    "`auxiliary_num_convs=0` requires `hidden_size` and `auxiliary_channels` to be equal, "
                    f"but got {self.in_channels} and {self.channels}."
                )
            self.convs = nn.Identity()
        else:
            conv_padding = (kernel_size // 2) * dilation
            # the first conv maps in_channels -> channels, every further one keeps `channels`
            conv_in_channels = [self.in_channels] + [self.channels] * (self.num_convs - 1)
            self.convs = nn.Sequential(
                *[
                    BeitConvModule(
                        in_channels, self.channels, kernel_size=kernel_size, padding=conv_padding, dilation=dilation
                    )
                    for in_channels in conv_in_channels
                ]
            )
        if self.concat_input:
            self.conv_cat = BeitConvModule(
                self.in_channels + self.channels, self.channels, kernel_size=kernel_size, padding=kernel_size // 2
            )

        self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1)

    def forward(self, encoder_hidden_states):
        """
        Compute auxiliary segmentation logits from the feature map at position ``self.in_index``.

        Args:
            encoder_hidden_states: Indexable sequence of feature maps; only the entry at ``self.in_index`` is used.

        Returns:
            Logits with ``config.num_labels`` channels.
        """
        # just take the relevant feature maps
        hidden_states = encoder_hidden_states[self.in_index]
        output = self.convs(hidden_states)
        if self.concat_input:
            # fuse the head input with the conv output along the channel dimension
            output = self.conv_cat(torch.cat([hidden_states, output], dim=1))
        output = self.classifier(output)
        return output
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
    """
    Beit Model transformer with a semantic segmentation head on top e.g. for ADE20k, CityScapes.
    """,
    BEIT_START_DOCSTRING,
)
class BeitForSemanticSegmentation(BeitPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.num_labels = config.num_labels
        self.beit = BeitModel(config, add_pooling_layer=False)

        # FPNs: turn the selected single-resolution encoder feature maps into a 4-level pyramid
        # fpn1: two stride-2 transposed convs -> 4x spatial upsampling
        self.fpn1 = nn.Sequential(
            nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2),
            nn.BatchNorm2d(config.hidden_size),
            nn.GELU(),
            nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2),
        )
        # fpn2: one stride-2 transposed conv -> 2x spatial upsampling
        self.fpn2 = nn.Sequential(
            nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2),
        )
        # fpn3: resolution kept as is
        self.fpn3 = nn.Identity()
        # fpn4: stride-2 max pooling -> 2x spatial downsampling
        self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Semantic segmentation head(s)
        self.decode_head = BeitUperHead(config)
        self.auxiliary_head = BeitFCNHead(config) if config.use_auxiliary_head else None

        self.init_weights()

    def compute_loss(self, logits, auxiliary_logits, labels):
        """
        Compute the segmentation loss, adding the weighted auxiliary loss when auxiliary logits are given.

        Args:
            logits: Logits of the main decode head; bilinearly resized to the last two dimensions of ``labels``.
            auxiliary_logits: Logits of the auxiliary head, or ``None`` when the model has no auxiliary head.
            labels: Ground truth segmentation maps. Entries equal to 255 are ignored by the loss.

        Returns:
            ``main_loss`` if ``auxiliary_logits`` is ``None``, otherwise ``main_loss +
            config.auxiliary_loss_weight * auxiliary_loss``.
        """
        loss_fct = CrossEntropyLoss(ignore_index=255)

        # upsample logits to the images' original size
        upsampled_logits = nn.functional.interpolate(
            logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
        )
        loss = loss_fct(upsampled_logits, labels)

        # The auxiliary term only exists when the auxiliary head is enabled. Reading the auxiliary
        # logits unconditionally raised UnboundLocalError for models without that head.
        if auxiliary_logits is not None:
            upsampled_auxiliary_logits = nn.functional.interpolate(
                auxiliary_logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
            )
            auxiliary_loss = loss_fct(upsampled_auxiliary_logits, labels)
            # compute weighted loss
            loss = loss + self.config.auxiliary_loss_weight * auxiliary_loss

        return loss

    @add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values=None,
        head_mask=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, height, width)`, `optional`):
            Ground truth semantic segmentation maps for computing the loss. Indices should be in :obj:`[0, ...,
            config.num_labels - 1]`. If :obj:`config.num_labels > 1`, a classification loss is computed
            (Cross-Entropy).

        Returns:

        Examples::

            >>> from transformers import BeitFeatureExtractor, BeitForSemanticSegmentation
            >>> from PIL import Image
            >>> import requests

            >>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
            >>> image = Image.open(requests.get(url, stream=True).raw)

            >>> feature_extractor = BeitFeatureExtractor.from_pretrained('microsoft/beit-base-finetuned-ade-640-640')
            >>> model = BeitForSemanticSegmentation.from_pretrained('microsoft/beit-base-finetuned-ade-640-640')

            >>> inputs = feature_extractor(images=image, return_tensors="pt")
            >>> outputs = model(**inputs)
            >>> # logits are of shape (batch_size, num_labels, height/4, width/4)
            >>> logits = outputs.logits
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        outputs = self.beit(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=True,  # we need the intermediate hidden states
            return_dict=return_dict,
        )

        encoder_hidden_states = outputs.hidden_states if return_dict else outputs[2]

        # only keep certain features, and reshape
        # note that we do +1 as the encoder_hidden_states also includes the initial embeddings
        features = [feature for idx, feature in enumerate(encoder_hidden_states) if idx + 1 in self.config.out_indices]
        batch_size = pixel_values.shape[0]
        # NOTE(review): assumes config.image_size and config.patch_size are ints describing a square grid - confirm
        patch_resolution = self.config.image_size // self.config.patch_size
        # drop the first token, then fold the token axis back into a (patch_resolution x patch_resolution) map
        features = [
            x[:, 1:, :].permute(0, 2, 1).reshape(batch_size, -1, patch_resolution, patch_resolution) for x in features
        ]

        # apply FPNs: the i-th selected feature map goes through the i-th rescaling op
        ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
        for i, feature in enumerate(features):
            features[i] = ops[i](feature)

        logits = self.decode_head(features)
        auxiliary_logits = None
        if self.auxiliary_head is not None:
            auxiliary_logits = self.auxiliary_head(features)

        loss = None
        if labels is not None:
            if self.config.num_labels == 1:
                raise ValueError("The number of labels should be greater than one")
            else:
                loss = self.compute_loss(logits, auxiliary_logits, labels)

        if not return_dict:
            # hidden states were always requested from the backbone; only pass them on if the caller asked
            if output_hidden_states:
                output = (logits,) + outputs[2:]
            else:
                output = (logits,) + outputs[3:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions,
        )
|
||||||
|
|||||||
@@ -606,6 +606,11 @@ class BeitForMaskedImageModeling:
|
|||||||
requires_backends(cls, ["torch"])
|
requires_backends(cls, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class BeitForSemanticSegmentation:
    # Stand-in exported under the real model's name. Construction is delegated to
    # `requires_backends` with the "torch" backend (presumably it raises when torch is
    # not installed - confirm against the helper).
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class BeitModel:
|
class BeitModel:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|||||||
@@ -18,6 +18,8 @@
|
|||||||
import inspect
|
import inspect
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
from datasets import load_dataset
|
||||||
|
|
||||||
from transformers import BeitConfig
|
from transformers import BeitConfig
|
||||||
from transformers.file_utils import cached_property, is_torch_available, is_vision_available
|
from transformers.file_utils import cached_property, is_torch_available, is_vision_available
|
||||||
from transformers.models.auto import get_values
|
from transformers.models.auto import get_values
|
||||||
@@ -31,7 +33,13 @@ if is_torch_available():
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from transformers import MODEL_MAPPING, BeitForImageClassification, BeitForMaskedImageModeling, BeitModel
|
from transformers import (
|
||||||
|
MODEL_MAPPING,
|
||||||
|
BeitForImageClassification,
|
||||||
|
BeitForMaskedImageModeling,
|
||||||
|
BeitForSemanticSegmentation,
|
||||||
|
BeitModel,
|
||||||
|
)
|
||||||
from transformers.models.beit.modeling_beit import BEIT_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple
|
from transformers.models.beit.modeling_beit import BEIT_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple
|
||||||
|
|
||||||
|
|
||||||
@@ -53,7 +61,7 @@ class BeitModelTester:
|
|||||||
is_training=True,
|
is_training=True,
|
||||||
use_labels=True,
|
use_labels=True,
|
||||||
hidden_size=32,
|
hidden_size=32,
|
||||||
num_hidden_layers=5,
|
num_hidden_layers=4,
|
||||||
num_attention_heads=4,
|
num_attention_heads=4,
|
||||||
intermediate_size=37,
|
intermediate_size=37,
|
||||||
hidden_act="gelu",
|
hidden_act="gelu",
|
||||||
@@ -63,6 +71,7 @@ class BeitModelTester:
|
|||||||
initializer_range=0.02,
|
initializer_range=0.02,
|
||||||
num_labels=3,
|
num_labels=3,
|
||||||
scope=None,
|
scope=None,
|
||||||
|
out_indices=[0, 1, 2, 3],
|
||||||
):
|
):
|
||||||
self.parent = parent
|
self.parent = parent
|
||||||
self.vocab_size = 100
|
self.vocab_size = 100
|
||||||
@@ -82,6 +91,7 @@ class BeitModelTester:
|
|||||||
self.type_sequence_label_size = type_sequence_label_size
|
self.type_sequence_label_size = type_sequence_label_size
|
||||||
self.initializer_range = initializer_range
|
self.initializer_range = initializer_range
|
||||||
self.scope = scope
|
self.scope = scope
|
||||||
|
self.out_indices = out_indices
|
||||||
|
|
||||||
def prepare_config_and_inputs(self):
|
def prepare_config_and_inputs(self):
|
||||||
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||||
@@ -109,6 +119,7 @@ class BeitModelTester:
|
|||||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||||
is_decoder=False,
|
is_decoder=False,
|
||||||
initializer_range=self.initializer_range,
|
initializer_range=self.initializer_range,
|
||||||
|
out_indices=self.out_indices,
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_and_check_model(self, config, pixel_values, labels):
|
def create_and_check_model(self, config, pixel_values, labels):
|
||||||
@@ -160,7 +171,9 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
all_model_classes = (
|
all_model_classes = (
|
||||||
(BeitModel, BeitForImageClassification, BeitForMaskedImageModeling) if is_torch_available() else ()
|
(BeitModel, BeitForImageClassification, BeitForMaskedImageModeling, BeitForSemanticSegmentation)
|
||||||
|
if is_torch_available()
|
||||||
|
else ()
|
||||||
)
|
)
|
||||||
|
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
@@ -212,11 +225,14 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
config.return_dict = True
|
config.return_dict = True
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
if model_class in get_values(MODEL_MAPPING):
|
|
||||||
continue
|
|
||||||
# we don't test BeitForMaskedImageModeling
|
# we don't test BeitForMaskedImageModeling
|
||||||
if model_class.__name__ == "BeitForMaskedImageModeling":
|
if model_class in [*get_values(MODEL_MAPPING), BeitForMaskedImageModeling]:
|
||||||
continue
|
continue
|
||||||
|
# TODO: remove the following 3 lines once we have a MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING
|
||||||
|
# this can then be incorporated into _prepare_for_class in test_modeling_common.py
|
||||||
|
elif model_class.__name__ == "BeitForSemanticSegmentation":
|
||||||
|
batch_size, num_channels, height, width = inputs_dict["pixel_values"].shape
|
||||||
|
inputs_dict["labels"] = torch.zeros([self.model_tester.batch_size, height, width]).long()
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
model.train()
|
model.train()
|
||||||
@@ -233,11 +249,17 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
config.return_dict = True
|
config.return_dict = True
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
if model_class in get_values(MODEL_MAPPING) or not model_class.supports_gradient_checkpointing:
|
|
||||||
continue
|
|
||||||
# we don't test BeitForMaskedImageModeling
|
# we don't test BeitForMaskedImageModeling
|
||||||
if model_class.__name__ == "BeitForMaskedImageModeling":
|
if (
|
||||||
|
model_class in [*get_values(MODEL_MAPPING), BeitForMaskedImageModeling]
|
||||||
|
or not model_class.supports_gradient_checkpointing
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
|
# TODO: remove the following 3 lines once we have a MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING
|
||||||
|
# this can then be incorporated into _prepare_for_class in test_modeling_common.py
|
||||||
|
elif model_class.__name__ == "BeitForSemanticSegmentation":
|
||||||
|
batch_size, num_channels, height, width = inputs_dict["pixel_values"].shape
|
||||||
|
inputs_dict["labels"] = torch.zeros([self.model_tester.batch_size, height, width]).long()
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
model.train()
|
model.train()
|
||||||
@@ -298,7 +320,8 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
model.eval()
|
model.eval()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||||
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
|
||||||
|
attentions = outputs.attentions
|
||||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
||||||
|
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
@@ -316,15 +339,9 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||||
|
|
||||||
if hasattr(self.model_tester, "num_hidden_states_types"):
|
self.assertEqual(out_len + 1, len(outputs))
|
||||||
added_hidden_states = self.model_tester.num_hidden_states_types
|
|
||||||
elif self.is_encoder_decoder:
|
|
||||||
added_hidden_states = 2
|
|
||||||
else:
|
|
||||||
added_hidden_states = 1
|
|
||||||
self.assertEqual(out_len + added_hidden_states, len(outputs))
|
|
||||||
|
|
||||||
self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
self_attentions = outputs.attentions
|
||||||
|
|
||||||
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
|
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
@@ -472,3 +489,32 @@ class BeitModelIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
expected_class_idx = 2396
|
expected_class_idx = 2396
|
||||||
self.assertEqual(logits.argmax(-1).item(), expected_class_idx)
|
self.assertEqual(logits.argmax(-1).item(), expected_class_idx)
|
||||||
|
|
||||||
|
@slow
def test_inference_semantic_segmentation(self):
    """Check logits shape and a 3x3x3 logits slice of the ADE20k checkpoint on a fixture image."""
    model = BeitForSemanticSegmentation.from_pretrained("microsoft/beit-base-finetuned-ade-640-640")
    model = model.to(torch_device)

    feature_extractor = BeitFeatureExtractor(do_resize=True, size=640, do_center_crop=False)

    ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
    image = Image.open(ds[0]["file"])
    inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)

    # forward pass (inference only, so no autograd graph needs to be recorded)
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits

    # verify the logits
    expected_shape = torch.Size((1, 150, 160, 160))
    self.assertEqual(logits.shape, expected_shape)

    expected_slice = torch.tensor(
        [
            [[-4.9225, -2.3954, -3.0522], [-2.8822, -1.0046, -1.7561], [-2.9549, -1.3228, -2.1347]],
            [[-5.8168, -3.4129, -4.0778], [-3.8651, -2.2214, -3.0277], [-3.8356, -2.4643, -3.3535]],
            [[-0.0078, 3.9952, 4.0754], [2.9856, 4.6944, 5.0035], [3.2413, 4.7813, 4.9969]],
        ]
    ).to(torch_device)

    self.assertTrue(torch.allclose(logits[0, :3, :3, :3], expected_slice, atol=1e-4))
|
||||||
|
|||||||
@@ -88,7 +88,7 @@ if is_torch_fx_available():
|
|||||||
def _config_zero_init(config):
|
def _config_zero_init(config):
|
||||||
configs_no_init = copy.deepcopy(config)
|
configs_no_init = copy.deepcopy(config)
|
||||||
for key in configs_no_init.__dict__.keys():
|
for key in configs_no_init.__dict__.keys():
|
||||||
if "_range" in key or "_std" in key or "initializer_factor" in key:
|
if "_range" in key or "_std" in key or "initializer_factor" in key or "layer_scale" in key:
|
||||||
setattr(configs_no_init, key, 1e-10)
|
setattr(configs_no_init, key, 1e-10)
|
||||||
return configs_no_init
|
return configs_no_init
|
||||||
|
|
||||||
|
|||||||
@@ -102,6 +102,7 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
|
|||||||
# models to ignore for model xxx mapping
|
# models to ignore for model xxx mapping
|
||||||
"SegformerDecodeHead",
|
"SegformerDecodeHead",
|
||||||
"SegformerForSemanticSegmentation",
|
"SegformerForSemanticSegmentation",
|
||||||
|
"BeitForSemanticSegmentation",
|
||||||
"FlaxBeitForMaskedImageModeling",
|
"FlaxBeitForMaskedImageModeling",
|
||||||
"BeitForMaskedImageModeling",
|
"BeitForMaskedImageModeling",
|
||||||
"CLIPTextModel",
|
"CLIPTextModel",
|
||||||
|
|||||||
Reference in New Issue
Block a user