From 6dc884abc85c7f9b3b46f536a541d34d9f513d54 Mon Sep 17 00:00:00 2001 From: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Date: Mon, 28 Nov 2022 20:33:49 +0100 Subject: [PATCH] [Maskformer] Add MaskFormerSwin backbone (#20344) * First draft * Fix backwards compatibility * More fixes * More fixes * Make backbone more general * Improve backbone * Improve test * Fix config checkpoint * Address comments * Use model_type * Address more comments * Fix special model names * Remove MaskFormerSwinModel and MaskFormerSwinPreTrainedModel from main init * Fix typo * Update backbone * Apply suggestion Co-authored-by: Niels Rogge --- docs/source/en/index.mdx | 1 + src/transformers/__init__.py | 6 +- .../models/auto/configuration_auto.py | 3 + src/transformers/models/auto/modeling_auto.py | 2 + .../models/maskformer/__init__.py | 16 +- .../configuration_maskformer_swin.py | 149 +++ .../models/maskformer/modeling_maskformer.py | 865 +--------------- .../maskformer/modeling_maskformer_swin.py | 925 ++++++++++++++++++ src/transformers/utils/dummy_pt_objects.py | 7 + .../test_modeling_maskformer_swin.py | 386 ++++++++ utils/check_copies.py | 3 +- utils/check_repo.py | 5 + 12 files changed, 1525 insertions(+), 843 deletions(-) create mode 100644 src/transformers/models/maskformer/configuration_maskformer_swin.py create mode 100644 src/transformers/models/maskformer/modeling_maskformer_swin.py create mode 100644 tests/models/maskformer/test_modeling_maskformer_swin.py diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index 790ce8f4d1..4cba7ee969 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -288,6 +288,7 @@ Flax), PyTorch, and/or TensorFlow. | Marian | ✅ | ❌ | ✅ | ✅ | ✅ | | MarkupLM | ✅ | ✅ | ✅ | ❌ | ❌ | | MaskFormer | ❌ | ❌ | ✅ | ❌ | ❌ | +| MaskFormerSwin | ❌ | ❌ | ❌ | ❌ | ❌ | | mBART | ✅ | ✅ | ✅ | ✅ | ✅ | | Megatron-BERT | ❌ | ❌ | ✅ | ❌ | ❌ | | MobileBERT | ✅ | ✅ | ✅ | ✅ | ❌ | diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 51e3430a5e..714ce483ea 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -295,7 +295,7 @@ _import_structure = { "MarkupLMProcessor", "MarkupLMTokenizer", ], - "models.maskformer": ["MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "MaskFormerConfig"], + "models.maskformer": ["MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "MaskFormerConfig", "MaskFormerSwinConfig"], "models.mbart": ["MBartConfig"], "models.mbart50": [], "models.mctct": ["MCTCT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MCTCTConfig", "MCTCTProcessor"], @@ -1622,6 +1622,7 @@ else: "MaskFormerForInstanceSegmentation", "MaskFormerModel", "MaskFormerPreTrainedModel", + "MaskFormerSwinBackbone", ] ) _import_structure["models.markuplm"].extend( @@ -3479,7 +3480,7 @@ if TYPE_CHECKING: MarkupLMProcessor, MarkupLMTokenizer, ) - from .models.maskformer import MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, MaskFormerConfig + from .models.maskformer import MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, MaskFormerConfig, MaskFormerSwinConfig from .models.mbart import MBartConfig from .models.mctct import MCTCT_PRETRAINED_CONFIG_ARCHIVE_MAP, MCTCTConfig, MCTCTProcessor from .models.megatron_bert import MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, MegatronBertConfig @@ -4573,6 +4574,7 @@ if TYPE_CHECKING: MaskFormerForInstanceSegmentation, MaskFormerModel, MaskFormerPreTrainedModel, + MaskFormerSwinBackbone, ) from .models.mbart import ( MBartForCausalLM, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 0ff1dd5671..6f83aea53e 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -98,6 +98,7 @@ CONFIG_MAPPING_NAMES = OrderedDict( ("marian", "MarianConfig"), ("markuplm", "MarkupLMConfig"), ("maskformer", "MaskFormerConfig"), + ("maskformer-swin", "MaskFormerSwinConfig"), ("mbart", "MBartConfig"), ("mctct", "MCTCTConfig"), ("megatron-bert", "MegatronBertConfig"), @@ -395,6 +396,7 @@ MODEL_NAMES_MAPPING = OrderedDict( ("marian", "Marian"), ("markuplm", "MarkupLM"), ("maskformer", "MaskFormer"), + ("maskformer-swin", "MaskFormerSwin"), ("mbart", "mBART"), ("mbart50", "mBART-50"), ("mctct", "M-CTC-T"), @@ -491,6 +493,7 @@ SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict( ("data2vec-text", "data2vec"), ("data2vec-vision", "data2vec"), ("donut-swin", "donut"), + ("maskformer-swin", "maskformer"), ("xclip", "x_clip"), ] ) diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 7f9d11ce0d..62344b10f6 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -97,6 +97,7 @@ MODEL_MAPPING_NAMES = OrderedDict( ("marian", "MarianModel"), ("markuplm", "MarkupLMModel"), ("maskformer", "MaskFormerModel"), + ("maskformer-swin", "MaskFormerSwinModel"), ("mbart", "MBartModel"), ("mctct", "MCTCTModel"), ("megatron-bert", "MegatronBertModel"), @@ -847,6 +848,7 @@ _MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict( [ # Backbone mapping + ("maskformer-swin", "MaskFormerSwinBackbone"), ("resnet", "ResNetBackbone"), ] ) diff --git a/src/transformers/models/maskformer/__init__.py b/src/transformers/models/maskformer/__init__.py index 4234f76dc5..db2529ca24 100644 --- a/src/transformers/models/maskformer/__init__.py +++ b/src/transformers/models/maskformer/__init__.py @@ -20,7 +20,10 @@ from typing import TYPE_CHECKING from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available -_import_structure = {"configuration_maskformer": ["MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "MaskFormerConfig"]} +_import_structure = { + "configuration_maskformer": ["MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "MaskFormerConfig"], + "configuration_maskformer_swin": ["MaskFormerSwinConfig"], +} try: if not is_vision_available(): @@ -43,9 +46,15 @@ else: "MaskFormerModel", "MaskFormerPreTrainedModel", ] + _import_structure["modeling_maskformer_swin"] = [ + "MaskFormerSwinBackbone", + "MaskFormerSwinModel", + "MaskFormerSwinPreTrainedModel", + ] if TYPE_CHECKING: from .configuration_maskformer import MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, MaskFormerConfig + from .configuration_maskformer_swin import MaskFormerSwinConfig try: if not is_vision_available(): @@ -66,6 +75,11 @@ if TYPE_CHECKING: MaskFormerModel, MaskFormerPreTrainedModel, ) + from .modeling_maskformer_swin import ( + MaskFormerSwinBackbone, + MaskFormerSwinModel, + MaskFormerSwinPreTrainedModel, + ) else: diff --git a/src/transformers/models/maskformer/configuration_maskformer_swin.py b/src/transformers/models/maskformer/configuration_maskformer_swin.py new file mode 100644 index 0000000000..2ba3cd3e20 --- /dev/null +++ b/src/transformers/models/maskformer/configuration_maskformer_swin.py @@ -0,0 +1,149 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" MaskFormer Swin Transformer model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class MaskFormerSwinConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`MaskFormerSwinModel`]. It is used to instantiate + a Donut model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the Swin + [microsoft/swin-tiny-patch4-window7-224](https://huggingface.co/microsoft/swin-tiny-patch4-window7-224) + architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 4): + The size (resolution) of each patch. + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + embed_dim (`int`, *optional*, defaults to 96): + Dimensionality of patch embedding. + depths (`List[int]`, *optional*, defaults to `[2, 2, 6, 2]`): + Depth of each layer in the Transformer encoder. + num_heads (`List[int]`, *optional*, defaults to `[3, 6, 12, 24]`): + Number of attention heads in each layer of the Transformer encoder. + window_size (`int`, *optional*, defaults to 7): + Size of windows. + mlp_ratio (`float`, *optional*, defaults to 4.0): + Ratio of MLP hidden dimensionality to embedding dimensionality. + qkv_bias (`bool`, *optional*, defaults to True): + Whether or not a learnable bias should be added to the queries, keys and values. + hidden_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings and encoder. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + drop_path_rate (`float`, *optional*, defaults to 0.1): + Stochastic depth rate. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder. If string, `"gelu"`, `"relu"`, + `"selu"` and `"gelu_new"` are supported. + use_absolute_embeddings (`bool`, *optional*, defaults to False): + Whether or not to add absolute position embeddings to the patch embeddings. + patch_norm (`bool`, *optional*, defaults to True): + Whether or not to add layer normalization after patch embedding. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + out_features (`List[str]`, *optional*): + If used as a backbone, list of feature names to output, e.g. `["stem", "stage1"]`. + + Example: + + ```python + >>> from transformers import MaskFormerSwinConfig, MaskFormerSwinModel + + >>> # Initializing a microsoft/swin-tiny-patch4-window7-224 style configuration + >>> configuration = MaskFormerSwinConfig() + + >>> # Initializing a model (with random weights) from the microsoft/swin-tiny-patch4-window7-224 style configuration + >>> model = MaskFormerSwinModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + model_type = "maskformer-swin" + + attribute_map = { + "num_attention_heads": "num_heads", + "num_hidden_layers": "num_layers", + } + + def __init__( + self, + image_size=224, + patch_size=4, + num_channels=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + mlp_ratio=4.0, + qkv_bias=True, + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + drop_path_rate=0.1, + hidden_act="gelu", + use_absolute_embeddings=False, + patch_norm=True, + initializer_range=0.02, + layer_norm_eps=1e-5, + out_features=None, + **kwargs + ): + super().__init__(**kwargs) + + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.embed_dim = embed_dim + self.depths = depths + self.num_layers = len(depths) + self.num_heads = num_heads + self.window_size = window_size + self.mlp_ratio = mlp_ratio + self.qkv_bias = qkv_bias + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.drop_path_rate = drop_path_rate + self.hidden_act = hidden_act + self.use_absolute_embeddings = use_absolute_embeddings + self.path_norm = patch_norm + self.layer_norm_eps = layer_norm_eps + self.initializer_range = initializer_range + # we set the hidden_size attribute in order to make Swin work with VisionEncoderDecoderModel + # this indicates the channel dimension after the last stage of the model + self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1)) + self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)] + if out_features is not None: + if not isinstance(out_features, list): + raise ValueError("out_features should be a list") + for feature in out_features: + if feature not in self.stage_names: + raise ValueError( + f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}" + ) + self.out_features = out_features diff --git a/src/transformers/models/maskformer/modeling_maskformer.py b/src/transformers/models/maskformer/modeling_maskformer.py index dabbb84f89..01ad3e6641 100644 --- a/src/transformers/models/maskformer/modeling_maskformer.py +++ b/src/transformers/models/maskformer/modeling_maskformer.py @@ -14,7 +14,6 @@ # limitations under the License. """ PyTorch MaskFormer model.""" -import collections.abc import math import random from dataclasses import dataclass @@ -25,12 +24,12 @@ import numpy as np import torch from torch import Tensor, nn +from transformers import AutoBackbone from transformers.utils import logging from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutputWithCrossAttentions -from ...modeling_utils import ModuleUtilsMixin, PreTrainedModel -from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...modeling_utils import PreTrainedModel from ...utils import ( ModelOutput, add_start_docstrings, @@ -40,8 +39,8 @@ from ...utils import ( requires_backends, ) from ..detr import DetrConfig -from ..swin import SwinConfig from .configuration_maskformer import MaskFormerConfig +from .configuration_maskformer_swin import MaskFormerSwinConfig if is_scipy_available(): @@ -60,71 +59,6 @@ MASKFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [ ] -@dataclass -class MaskFormerSwinModelOutputWithPooling(ModelOutput): - """ - Class for MaskFormerSwinModel's outputs that also contains the spatial dimensions of the hidden states. - - Args: - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): - Last layer hidden-state after a mean pooling operation. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of - shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - hidden_states_spatial_dimensions (`tuple(tuple(int, int))`, *optional*): - A tuple containing the spatial dimension of each `hidden_state` needed to reshape the `hidden_states` to - `batch, channels, height, width`. Due to padding, their spatial size cannot be inferred before the - `forward` method. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - last_hidden_state: torch.FloatTensor = None - pooler_output: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor]] = None - hidden_states_spatial_dimensions: Tuple[Tuple[int, int]] = None - attentions: Optional[Tuple[torch.FloatTensor]] = None - - -@dataclass -class MaskFormerSwinBaseModelOutput(ModelOutput): - """ - Class for SwinEncoder's outputs. - - Args: - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of - shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - hidden_states_spatial_dimensions (`tuple(tuple(int, int))`, *optional*): - A tuple containing the spatial dimension of each `hidden_state` needed to reshape the `hidden_states` to - `batch, channels, height, width`. Due to padding, their spatial size cannot inferred before the `forward` - method. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - last_hidden_state: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor]] = None - hidden_states_spatial_dimensions: Tuple[Tuple[int, int]] = None - attentions: Optional[Tuple[torch.FloatTensor]] = None - - @dataclass # Copied from transformers.models.detr.modeling_detr.DetrDecoderOutput class DetrDecoderOutput(BaseModelOutputWithCrossAttentions): @@ -471,714 +405,6 @@ def pair_wise_sigmoid_focal_loss(inputs: Tensor, labels: Tensor, alpha: float = return loss / height_and_width -# Copied from transformers.models.swin.modeling_swin.window_partition -def window_partition(input_feature, window_size): - """ - Partitions the given input into windows. - """ - batch_size, height, width, num_channels = input_feature.shape - input_feature = input_feature.view( - batch_size, height // window_size, window_size, width // window_size, window_size, num_channels - ) - windows = input_feature.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels) - return windows - - -# Copied from transformers.models.swin.modeling_swin.window_reverse -def window_reverse(windows, window_size, height, width): - """ - Merges windows to produce higher resolution features. - """ - num_channels = windows.shape[-1] - windows = windows.view(-1, height // window_size, width // window_size, window_size, window_size, num_channels) - windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, height, width, num_channels) - return windows - - -# Copied from transformers.models.swin.modeling_swin.drop_path -def drop_path(input, drop_prob=0.0, training=False, scale_by_keep=True): - """ - Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - - Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, - however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... - See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the - layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the - argument. - """ - if drop_prob == 0.0 or not training: - return input - keep_prob = 1 - drop_prob - shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets - random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) - random_tensor.floor_() # binarize - output = input.div(keep_prob) * random_tensor - return output - - -class MaskFormerSwinEmbeddings(nn.Module): - """ - Construct the patch and position embeddings. - """ - - def __init__(self, config): - super().__init__() - - self.patch_embeddings = MaskFormerSwinPatchEmbeddings(config) - num_patches = self.patch_embeddings.num_patches - self.patch_grid = self.patch_embeddings.grid_size - - if config.use_absolute_embeddings: - self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.embed_dim)) - else: - self.position_embeddings = None - - self.norm = nn.LayerNorm(config.embed_dim) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - - def forward(self, pixel_values): - embeddings, output_dimensions = self.patch_embeddings(pixel_values) - embeddings = self.norm(embeddings) - - if self.position_embeddings is not None: - embeddings = embeddings + self.position_embeddings - - embeddings = self.dropout(embeddings) - - return embeddings, output_dimensions - - -class MaskFormerSwinPatchEmbeddings(nn.Module): - """ - Image to Patch Embedding, including padding. - """ - - def __init__(self, config): - super().__init__() - image_size, patch_size = config.image_size, config.patch_size - num_channels, hidden_size = config.num_channels, config.embed_dim - - image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) - patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) - num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) - self.image_size = image_size - self.patch_size = patch_size - self.num_channels = num_channels - self.num_patches = num_patches - self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1]) - - self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) - - def maybe_pad(self, pixel_values, height, width): - if width % self.patch_size[1] != 0: - pad_values = (0, self.patch_size[1] - width % self.patch_size[1]) - pixel_values = nn.functional.pad(pixel_values, pad_values) - if height % self.patch_size[0] != 0: - pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0]) - pixel_values = nn.functional.pad(pixel_values, pad_values) - return pixel_values - - def forward(self, pixel_values): - _, num_channels, height, width = pixel_values.shape - if num_channels != self.num_channels: - raise ValueError( - "Make sure that the channel dimension of the pixel values match with the one set in the configuration." - ) - # pad the input to be divisible by self.patch_size, if needed - pixel_values = self.maybe_pad(pixel_values, height, width) - embeddings = self.projection(pixel_values) - _, _, height, width = embeddings.shape - output_dimensions = (height, width) - embeddings_flat = embeddings.flatten(2).transpose(1, 2) - - return embeddings_flat, output_dimensions - - -class MaskFormerSwinPatchMerging(nn.Module): - """ - Patch Merging Layer for maskformer model. - - Args: - input_resolution (`Tuple[int]`): - Resolution of input feature. - dim (`int`): - Number of input channels. - norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`): - Normalization layer class. - """ - - def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): - super().__init__() - self.dim = dim - self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) - self.norm = norm_layer(4 * dim) - - def maybe_pad(self, input_feature, width, height): - should_pad = (height % 2 == 1) or (width % 2 == 1) - if should_pad: - pad_values = (0, 0, 0, width % 2, 0, height % 2) - input_feature = nn.functional.pad(input_feature, pad_values) - - return input_feature - - def forward(self, input_feature, input_dimensions): - height, width = input_dimensions - # `dim` is height * width - batch_size, dim, num_channels = input_feature.shape - - input_feature = input_feature.view(batch_size, height, width, num_channels) - # pad input to be disible by width and height, if needed - input_feature = self.maybe_pad(input_feature, height, width) - # [batch_size, height/2, width/2, num_channels] - input_feature_0 = input_feature[:, 0::2, 0::2, :] - # [batch_size, height/2, width/2, num_channels] - input_feature_1 = input_feature[:, 1::2, 0::2, :] - # [batch_size, height/2, width/2, num_channels] - input_feature_2 = input_feature[:, 0::2, 1::2, :] - # [batch_size, height/2, width/2, num_channels] - input_feature_3 = input_feature[:, 1::2, 1::2, :] - # batch_size height/2 width/2 4*num_channels - input_feature = torch.cat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1) - input_feature = input_feature.view(batch_size, -1, 4 * num_channels) # batch_size height/2*width/2 4*C - - input_feature = self.norm(input_feature) - input_feature = self.reduction(input_feature) - - return input_feature - - -# Copied from transformers.models.swin.modeling_swin.SwinDropPath with Swin->MaskFormerSwin -class MaskFormerSwinDropPath(nn.Module): - """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" - - def __init__(self, drop_prob: Optional[float] = None) -> None: - super().__init__() - self.drop_prob = drop_prob - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return drop_path(x, self.drop_prob, self.training) - - def extra_repr(self) -> str: - return "p={}".format(self.drop_prob) - - -# Copied from transformers.models.swin.modeling_swin.SwinSelfAttention with Swin->MaskFormerSwin -class MaskFormerSwinSelfAttention(nn.Module): - def __init__(self, config, dim, num_heads, window_size): - super().__init__() - if dim % num_heads != 0: - raise ValueError( - f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})" - ) - - self.num_attention_heads = num_heads - self.attention_head_size = int(dim / num_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - self.window_size = ( - window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size) - ) - - self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads) - ) - - # get pair-wise relative position index for each token inside the window - coords_h = torch.arange(self.window_size[0]) - coords_w = torch.arange(self.window_size[1]) - coords = torch.stack(torch.meshgrid([coords_h, coords_w])) - coords_flatten = torch.flatten(coords, 1) - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] - relative_coords = relative_coords.permute(1, 2, 0).contiguous() - relative_coords[:, :, 0] += self.window_size[0] - 1 - relative_coords[:, :, 1] += self.window_size[1] - 1 - relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 - relative_position_index = relative_coords.sum(-1) - self.register_buffer("relative_position_index", relative_position_index) - - self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias) - self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias) - self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias) - - self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - output_attentions: Optional[bool] = False, - ) -> Tuple[torch.Tensor]: - batch_size, dim, num_channels = hidden_states.shape - mixed_query_layer = self.query(hidden_states) - - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) - - # Take the dot product between "query" and "key" to get the raw attention scores. - attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - - attention_scores = attention_scores / math.sqrt(self.attention_head_size) - - relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)] - relative_position_bias = relative_position_bias.view( - self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1 - ) - - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() - attention_scores = attention_scores + relative_position_bias.unsqueeze(0) - - if attention_mask is not None: - # Apply the attention mask is (precomputed for all layers in MaskFormerSwinModel forward() function) - mask_shape = attention_mask.shape[0] - attention_scores = attention_scores.view( - batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim - ) - attention_scores = attention_scores + attention_mask.unsqueeze(1).unsqueeze(0) - attention_scores = attention_scores.view(-1, self.num_attention_heads, dim, dim) - - # Normalize the attention scores to probabilities. - attention_probs = nn.functional.softmax(attention_scores, dim=-1) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.dropout(attention_probs) - - # Mask heads if we want to - if head_mask is not None: - attention_probs = attention_probs * head_mask - - context_layer = torch.matmul(attention_probs, value_layer) - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(new_context_layer_shape) - - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - return outputs - - -# Copied from transformers.models.swin.modeling_swin.SwinSelfOutput with Swin->MaskFormerSwin -class MaskFormerSwinSelfOutput(nn.Module): - def __init__(self, config, dim): - super().__init__() - self.dense = nn.Linear(dim, dim) - self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - - def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states) - - return hidden_states - - -# Copied from transformers.models.swin.modeling_swin.SwinAttention with Swin->MaskFormerSwin -class MaskFormerSwinAttention(nn.Module): - def __init__(self, config, dim, num_heads, window_size): - super().__init__() - self.self = MaskFormerSwinSelfAttention(config, dim, num_heads, window_size) - self.output = MaskFormerSwinSelfOutput(config, dim) - self.pruned_heads = set() - - def prune_heads(self, heads): - if len(heads) == 0: - return - heads, index = find_pruneable_heads_and_indices( - heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads - ) - - # Prune linear layers - self.self.query = prune_linear_layer(self.self.query, index) - self.self.key = prune_linear_layer(self.self.key, index) - self.self.value = prune_linear_layer(self.self.value, index) - self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) - - # Update hyper params and store pruned heads - self.self.num_attention_heads = self.self.num_attention_heads - len(heads) - self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads - self.pruned_heads = self.pruned_heads.union(heads) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - output_attentions: Optional[bool] = False, - ) -> Tuple[torch.Tensor]: - self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions) - attention_output = self.output(self_outputs[0], hidden_states) - outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them - return outputs - - -# Copied from transformers.models.swin.modeling_swin.SwinIntermediate with Swin->MaskFormerSwin -class MaskFormerSwinIntermediate(nn.Module): - def __init__(self, config, dim): - super().__init__() - self.dense = nn.Linear(dim, int(config.mlp_ratio * dim)) - if isinstance(config.hidden_act, str): - self.intermediate_act_fn = ACT2FN[config.hidden_act] - else: - self.intermediate_act_fn = config.hidden_act - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - hidden_states = self.dense(hidden_states) - hidden_states = self.intermediate_act_fn(hidden_states) - return hidden_states - - -# Copied from transformers.models.swin.modeling_swin.SwinOutput with Swin->MaskFormerSwin -class MaskFormerSwinOutput(nn.Module): - def __init__(self, config, dim): - super().__init__() - self.dense = nn.Linear(int(config.mlp_ratio * dim), dim) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states) - return hidden_states - - -class MaskFormerSwinLayer(nn.Module): - def __init__(self, config, dim, input_resolution, num_heads, shift_size=0): - super().__init__() - self.chunk_size_feed_forward = config.chunk_size_feed_forward - self.shift_size = shift_size - self.window_size = config.window_size - self.input_resolution = input_resolution - self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps) - self.attention = MaskFormerSwinAttention(config, dim, num_heads, self.window_size) - self.drop_path = ( - MaskFormerSwinDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity() - ) - self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps) - self.intermediate = MaskFormerSwinIntermediate(config, dim) - self.output = MaskFormerSwinOutput(config, dim) - - def get_attn_mask(self, input_resolution): - if self.shift_size > 0: - # calculate attention mask for SW-MSA - height, width = input_resolution - img_mask = torch.zeros((1, height, width, 1)) - height_slices = ( - slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None), - ) - width_slices = ( - slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None), - ) - count = 0 - for height_slice in height_slices: - for width_slice in width_slices: - img_mask[:, height_slice, width_slice, :] = count - count += 1 - - mask_windows = window_partition(img_mask, self.window_size) - mask_windows = mask_windows.view(-1, self.window_size * self.window_size) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - else: - attn_mask = None - return attn_mask - - def maybe_pad(self, hidden_states, height, width): - pad_left = pad_top = 0 - pad_rigth = (self.window_size - width % self.window_size) % self.window_size - pad_bottom = (self.window_size - height % self.window_size) % self.window_size - pad_values = (0, 0, pad_left, pad_rigth, pad_top, pad_bottom) - hidden_states = nn.functional.pad(hidden_states, pad_values) - return hidden_states, pad_values - - def forward(self, hidden_states, input_dimensions, head_mask=None, output_attentions=False): - height, width = input_dimensions - batch_size, dim, channels = hidden_states.size() - shortcut = hidden_states - - hidden_states = self.layernorm_before(hidden_states) - hidden_states = hidden_states.view(batch_size, height, width, channels) - # pad hidden_states to multiples of window size - hidden_states, pad_values = self.maybe_pad(hidden_states, height, width) - - _, height_pad, width_pad, _ = hidden_states.shape - # cyclic shift - if self.shift_size > 0: - shifted_hidden_states = torch.roll(hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) - else: - shifted_hidden_states = hidden_states - - # partition windows - hidden_states_windows = window_partition(shifted_hidden_states, self.window_size) - hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels) - attn_mask = self.get_attn_mask((height_pad, width_pad)) - if attn_mask is not None: - attn_mask = attn_mask.to(hidden_states_windows.device) - - self_attention_outputs = self.attention( - hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions - ) - - attention_output = self_attention_outputs[0] - - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - - attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels) - shifted_windows = window_reverse( - attention_windows, self.window_size, height_pad, width_pad - ) # B height' width' C - - # reverse cyclic shift - if self.shift_size > 0: - attention_windows = torch.roll(shifted_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) - else: - attention_windows = shifted_windows - - was_padded = pad_values[3] > 0 or pad_values[5] > 0 - if was_padded: - attention_windows = attention_windows[:, :height, :width, :].contiguous() - - attention_windows = attention_windows.view(batch_size, height * width, channels) - - hidden_states = shortcut + self.drop_path(attention_windows) - - layer_output = self.layernorm_after(hidden_states) - layer_output = self.intermediate(layer_output) - layer_output = hidden_states + self.output(layer_output) - - outputs = (layer_output,) + outputs - - return outputs - - -class MaskFormerSwinStage(nn.Module): - # Copied from transformers.models.swin.modeling_swin.SwinStage.__init__ with Swin->MaskFormerSwin - def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, downsample): - super().__init__() - self.config = config - self.dim = dim - self.blocks = nn.ModuleList( - [ - MaskFormerSwinLayer( - config=config, - dim=dim, - input_resolution=input_resolution, - num_heads=num_heads, - shift_size=0 if (i % 2 == 0) else config.window_size // 2, - ) - for i in range(depth) - ] - ) - - # patch merging layer - if downsample is not None: - self.downsample = downsample(input_resolution, dim=dim, norm_layer=nn.LayerNorm) - else: - self.downsample = None - - self.pointing = False - - def forward( - self, hidden_states, input_dimensions, head_mask=None, output_attentions=False, output_hidden_states=False - ): - all_hidden_states = () if output_hidden_states else None - - height, width = input_dimensions - for i, block_module in enumerate(self.blocks): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer_head_mask = head_mask[i] if head_mask is not None else None - - block_hidden_states = block_module(hidden_states, input_dimensions, layer_head_mask, output_attentions) - - hidden_states = block_hidden_states[0] - - if output_hidden_states: - all_hidden_states += (hidden_states,) - - if self.downsample is not None: - height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2 - output_dimensions = (height, width, height_downsampled, width_downsampled) - hidden_states = self.downsample(hidden_states, input_dimensions) - else: - output_dimensions = (height, width, height, width) - - return hidden_states, output_dimensions, all_hidden_states - - -class MaskFormerSwinEncoder(nn.Module): - # Copied from transformers.models.swin.modeling_swin.SwinEncoder.__init__ with Swin->MaskFormerSwin - def __init__(self, config, grid_size): - super().__init__() - self.num_layers = len(config.depths) - self.config = config - dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))] - self.layers = nn.ModuleList( - [ - MaskFormerSwinStage( - config=config, - dim=int(config.embed_dim * 2**i_layer), - input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)), - depth=config.depths[i_layer], - num_heads=config.num_heads[i_layer], - drop_path=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])], - downsample=MaskFormerSwinPatchMerging if (i_layer < self.num_layers - 1) else None, - ) - for i_layer in range(self.num_layers) - ] - ) - - self.gradient_checkpointing = False - - def forward( - self, - hidden_states, - input_dimensions, - head_mask=None, - output_attentions=False, - output_hidden_states=False, - return_dict=True, - ): - all_hidden_states = () if output_hidden_states else None - all_input_dimensions = () - all_self_attentions = () if output_attentions else None - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - for i, layer_module in enumerate(self.layers): - layer_head_mask = head_mask[i] if head_mask is not None else None - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_hidden_states, output_dimensions, layer_all_hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), hidden_states, layer_head_mask - ) - else: - layer_hidden_states, output_dimensions, layer_all_hidden_states = layer_module( - hidden_states, - input_dimensions, - layer_head_mask, - output_attentions, - output_hidden_states, - ) - - input_dimensions = (output_dimensions[-2], output_dimensions[-1]) - all_input_dimensions += (input_dimensions,) - if output_hidden_states: - all_hidden_states += (layer_all_hidden_states,) - - hidden_states = layer_hidden_states - - if output_attentions: - all_self_attentions = all_self_attentions + (layer_all_hidden_states[1],) - - if not return_dict: - return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) - - return MaskFormerSwinBaseModelOutput( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - hidden_states_spatial_dimensions=all_input_dimensions, - attentions=all_self_attentions, - ) - - -class MaskFormerSwinModel(nn.Module, ModuleUtilsMixin): - def __init__(self, config, add_pooling_layer=True): - super().__init__() - self.config = config - self.num_layers = len(config.depths) - self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1)) - - self.embeddings = MaskFormerSwinEmbeddings(config) - self.encoder = MaskFormerSwinEncoder(config, self.embeddings.patch_grid) - - self.layernorm = nn.LayerNorm(self.num_features, eps=config.layer_norm_eps) - self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None - - def get_input_embeddings(self): - return self.embeddings.patch_embeddings - - def _prune_heads(self, heads_to_prune): - """ - Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base - class PreTrainedModel - """ - for layer, heads in heads_to_prune.items(): - self.encoder.layer[layer].attention.prune_heads(heads) - - def forward( - self, - pixel_values=None, - head_mask=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if pixel_values is None: - raise ValueError("You have to specify pixel_values") - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - head_mask = self.get_head_mask(head_mask, len(self.config.depths)) - - embedding_output, input_dimensions = self.embeddings(pixel_values) - - encoder_outputs = self.encoder( - embedding_output, - input_dimensions, - head_mask=head_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - sequence_output = encoder_outputs.last_hidden_state - sequence_output = self.layernorm(sequence_output) - - pooled_output = None - if self.pooler is not None: - pooled_output = self.pooler(sequence_output.transpose(1, 2)) - pooled_output = torch.flatten(pooled_output, 1) - - if not return_dict: - return (sequence_output, pooled_output) + encoder_outputs[1:] - - hidden_states_spatial_dimensions = (input_dimensions,) + encoder_outputs.hidden_states_spatial_dimensions - - return MaskFormerSwinModelOutputWithPooling( - last_hidden_state=sequence_output, - pooler_output=pooled_output, - hidden_states=encoder_outputs.hidden_states, - hidden_states_spatial_dimensions=hidden_states_spatial_dimensions, - attentions=encoder_outputs.attentions, - ) - - # Copied from transformers.models.detr.modeling_detr.DetrAttention class DetrAttention(nn.Module): """ @@ -1923,52 +1149,6 @@ class MaskFormerLoss(nn.Module): return num_masks_pt -class MaskFormerSwinTransformerBackbone(nn.Module): - """ - This class uses [`MaskFormerSwinModel`] to reshape its `hidden_states` from (`batch_size, sequence_length, - hidden_size)` to (`batch_size, num_channels, height, width)`). - - Args: - config (`SwinConfig`): - The configuration used by [`MaskFormerSwinModel`]. - """ - - def __init__(self, config: SwinConfig): - super().__init__() - self.model = MaskFormerSwinModel(config) - self.hidden_states_norms = nn.ModuleList([nn.LayerNorm(out_shape) for out_shape in self.outputs_shapes]) - - def forward(self, *args, **kwargs) -> List[Tensor]: - output = self.model(*args, **kwargs, output_hidden_states=True) - hidden_states_permuted: List[Tensor] = [] - # we need to reshape the hidden state to their original spatial dimensions - # skipping the embeddings - hidden_states: Tuple[Tuple[Tensor]] = output.hidden_states[1:] - # spatial dimensions contains all the heights and widths of each stage, including after the embeddings - spatial_dimensions: Tuple[Tuple[int, int]] = output.hidden_states_spatial_dimensions - for i, (hidden_state, (height, width)) in enumerate(zip(hidden_states, spatial_dimensions)): - norm = self.hidden_states_norms[i] - # the last element corespond to the layer's last block output but before patch merging - hidden_state_unpolled = hidden_state[-1] - hidden_state_norm = norm(hidden_state_unpolled) - # our pixel decoder (FPN) expect 3D tensors (features) - batch_size, _, hidden_size = hidden_state_norm.shape - # reshape our tensor "b (h w) d -> b d h w" - hidden_state_permuted = ( - hidden_state_norm.permute(0, 2, 1).view((batch_size, hidden_size, height, width)).contiguous() - ) - hidden_states_permuted.append(hidden_state_permuted) - return hidden_states_permuted - - @property - def input_resolutions(self) -> List[int]: - return [layer.input_resolution for layer in self.model.encoder.layers] - - @property - def outputs_shapes(self) -> List[int]: - return [layer.dim for layer in self.model.encoder.layers] - - class MaskFormerFPNConvLayer(nn.Module): def __init__(self, in_features: int, out_features: int, kernel_size: int = 3, padding: int = 1): """ @@ -2065,7 +1245,7 @@ class MaskFormerPixelDecoder(nn.Module): def __init__(self, *args, feature_size: int = 256, mask_feature_size: int = 256, **kwargs): """ Pixel Decoder Module proposed in [Per-Pixel Classification is Not All You Need for Semantic - Segmentation](https://arxiv.org/abs/2107.06278). It first runs the backbone's feature into a Feature Pyramid + Segmentation](https://arxiv.org/abs/2107.06278). It first runs the backbone's features into a Feature Pyramid Network creating a list of feature maps. Then, it projects the last one to the correct `mask_size`. Args: @@ -2075,13 +1255,15 @@ class MaskFormerPixelDecoder(nn.Module): The features (channels) of the target masks size \\C_{\epsilon}\\ in the paper. """ super().__init__() + self.fpn = MaskFormerFPNModel(*args, feature_size=feature_size, **kwargs) self.mask_projection = nn.Conv2d(feature_size, mask_feature_size, kernel_size=3, padding=1) def forward(self, features: List[Tensor], output_hidden_states: bool = False) -> MaskFormerPixelDecoderOutput: - fpn_features: List[Tensor] = self.fpn(features) + fpn_features = self.fpn(features) # we use the last feature map last_feature_projected = self.mask_projection(fpn_features[-1]) + return MaskFormerPixelDecoderOutput( last_hidden_state=last_feature_projected, hidden_states=tuple(fpn_features) if output_hidden_states else () ) @@ -2193,17 +1375,26 @@ class MaskFormerPixelLevelModule(nn.Module): The configuration used to instantiate this model. """ super().__init__() - self.encoder = MaskFormerSwinTransformerBackbone(config.backbone_config) + + # TODD: add method to load pretrained weights of backbone + backbone_config = config.backbone_config + if backbone_config.model_type == "swin": + # for backwards compatibility + backbone_config = MaskFormerSwinConfig.from_dict(backbone_config.to_dict()) + backbone_config.out_features = ["stage1", "stage2", "stage3", "stage4"] + self.encoder = AutoBackbone.from_config(backbone_config) + + feature_channels = self.encoder.channels self.decoder = MaskFormerPixelDecoder( - in_features=self.encoder.outputs_shapes[-1], + in_features=feature_channels[-1], feature_size=config.fpn_feature_size, mask_feature_size=config.mask_feature_size, - lateral_widths=self.encoder.outputs_shapes[:-1], + lateral_widths=feature_channels[:-1], ) def forward(self, pixel_values: Tensor, output_hidden_states: bool = False) -> MaskFormerPixelLevelModuleOutput: - features: List[Tensor] = self.encoder(pixel_values) - decoder_output: MaskFormerPixelDecoderOutput = self.decoder(features, output_hidden_states) + features = self.encoder(pixel_values).feature_maps + decoder_output = self.decoder(features, output_hidden_states) return MaskFormerPixelLevelModuleOutput( # the last feature is actually the output from the last layer encoder_last_hidden_state=features[-1], @@ -2335,8 +1526,8 @@ class MaskFormerPreTrainedModel(PreTrainedModel): module.weight.data[module.padding_idx].zero_() def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, MaskFormerSwinEncoder): - module.gradient_checkpointing = value + if isinstance(module, MaskFormerPixelLevelModule): + module.encoder.gradient_checkpointing = value if isinstance(module, DetrDecoder): module.gradient_checkpointing = value @@ -2350,7 +1541,7 @@ class MaskFormerModel(MaskFormerPreTrainedModel): super().__init__(config) self.pixel_level_module = MaskFormerPixelLevelModule(config) self.transformer_module = MaskFormerTransformerModule( - in_features=self.pixel_level_module.encoder.outputs_shapes[-1], config=config + in_features=self.pixel_level_module.encoder.channels[-1], config=config ) self.post_init() @@ -2407,15 +1598,11 @@ class MaskFormerModel(MaskFormerPreTrainedModel): if pixel_mask is None: pixel_mask = torch.ones((batch_size, height, width), device=pixel_values.device) - pixel_level_module_output: MaskFormerPixelLevelModuleOutput = self.pixel_level_module( - pixel_values, output_hidden_states - ) + pixel_level_module_output = self.pixel_level_module(pixel_values, output_hidden_states) image_features = pixel_level_module_output.encoder_last_hidden_state pixel_embeddings = pixel_level_module_output.decoder_last_hidden_state - transformer_module_output: DetrDecoderOutput = self.transformer_module( - image_features, output_hidden_states, output_attentions - ) + transformer_module_output = self.transformer_module(image_features, output_hidden_states, output_attentions) queries = transformer_module_output.last_hidden_state encoder_hidden_states = None diff --git a/src/transformers/models/maskformer/modeling_maskformer_swin.py b/src/transformers/models/maskformer/modeling_maskformer_swin.py new file mode 100644 index 0000000000..08a09a289e --- /dev/null +++ b/src/transformers/models/maskformer/modeling_maskformer_swin.py @@ -0,0 +1,925 @@ +# coding=utf-8 +# Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""MaskFormer Swin Transformer. The reason Swin Transformer is implemented here is because MaskFormer uses the hidden +states before downsampling, which is different from the default Swin Transformer.""" + +import collections.abc +import math +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +from torch import Tensor, nn + +from ...activations import ACT2FN +from ...file_utils import ModelOutput +from ...modeling_outputs import BackboneOutput +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from .configuration_maskformer_swin import MaskFormerSwinConfig + + +@dataclass +class MaskFormerSwinModelOutputWithPooling(ModelOutput): + """ + Class for MaskFormerSwinModel's outputs that also contains the spatial dimensions of the hidden states. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): + Last layer hidden-state after a mean pooling operation. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + hidden_states_spatial_dimensions (`tuple(tuple(int, int))`, *optional*): + A tuple containing the spatial dimension of each `hidden_state` needed to reshape the `hidden_states` to + `batch, channels, height, width`. Due to padding, their spatial size cannot be inferred before the + `forward` method. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + last_hidden_state: torch.FloatTensor = None + pooler_output: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + hidden_states_spatial_dimensions: Tuple[Tuple[int, int]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class MaskFormerSwinBaseModelOutput(ModelOutput): + """ + Class for SwinEncoder's outputs. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + hidden_states_spatial_dimensions (`tuple(tuple(int, int))`, *optional*): + A tuple containing the spatial dimension of each `hidden_state` needed to reshape the `hidden_states` to + `batch, channels, height, width`. Due to padding, their spatial size cannot inferred before the `forward` + method. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + hidden_states_spatial_dimensions: Tuple[Tuple[int, int]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +# Copied from transformers.models.swin.modeling_swin.window_partition +def window_partition(input_feature, window_size): + """ + Partitions the given input into windows. + """ + batch_size, height, width, num_channels = input_feature.shape + input_feature = input_feature.view( + batch_size, height // window_size, window_size, width // window_size, window_size, num_channels + ) + windows = input_feature.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels) + return windows + + +# Copied from transformers.models.swin.modeling_swin.window_reverse +def window_reverse(windows, window_size, height, width): + """ + Merges windows to produce higher resolution features. + """ + num_channels = windows.shape[-1] + windows = windows.view(-1, height // window_size, width // window_size, window_size, window_size, num_channels) + windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, height, width, num_channels) + return windows + + +# Copied from transformers.models.swin.modeling_swin.drop_path +def drop_path(input, drop_prob=0.0, training=False, scale_by_keep=True): + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. + """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) + random_tensor.floor_() # binarize + output = input.div(keep_prob) * random_tensor + return output + + +class MaskFormerSwinEmbeddings(nn.Module): + """ + Construct the patch and position embeddings. + """ + + def __init__(self, config): + super().__init__() + + self.patch_embeddings = MaskFormerSwinPatchEmbeddings(config) + num_patches = self.patch_embeddings.num_patches + self.patch_grid = self.patch_embeddings.grid_size + + if config.use_absolute_embeddings: + self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.embed_dim)) + else: + self.position_embeddings = None + + self.norm = nn.LayerNorm(config.embed_dim) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, pixel_values): + embeddings, output_dimensions = self.patch_embeddings(pixel_values) + embeddings = self.norm(embeddings) + + if self.position_embeddings is not None: + embeddings = embeddings + self.position_embeddings + + embeddings = self.dropout(embeddings) + + return embeddings, output_dimensions + + +# Copied from transformers.models.swin.modeling_swin.SwinPatchEmbeddings +class MaskFormerSwinPatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. + """ + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.embed_dim + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1]) + + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + + def maybe_pad(self, pixel_values, height, width): + if width % self.patch_size[1] != 0: + pad_values = (0, self.patch_size[1] - width % self.patch_size[1]) + pixel_values = nn.functional.pad(pixel_values, pad_values) + if height % self.patch_size[0] != 0: + pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0]) + pixel_values = nn.functional.pad(pixel_values, pad_values) + return pixel_values + + def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]: + _, num_channels, height, width = pixel_values.shape + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + # pad the input to be divisible by self.patch_size, if needed + pixel_values = self.maybe_pad(pixel_values, height, width) + embeddings = self.projection(pixel_values) + _, _, height, width = embeddings.shape + output_dimensions = (height, width) + embeddings = embeddings.flatten(2).transpose(1, 2) + + return embeddings, output_dimensions + + +# Copied from transformers.models.swin.modeling_swin.SwinPatchMerging +class MaskFormerSwinPatchMerging(nn.Module): + """ + Patch Merging Layer. + + Args: + input_resolution (`Tuple[int]`): + Resolution of input feature. + dim (`int`): + Number of input channels. + norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`): + Normalization layer class. + """ + + def __init__(self, input_resolution: Tuple[int], dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None: + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def maybe_pad(self, input_feature, height, width): + should_pad = (height % 2 == 1) or (width % 2 == 1) + if should_pad: + pad_values = (0, 0, 0, width % 2, 0, height % 2) + input_feature = nn.functional.pad(input_feature, pad_values) + + return input_feature + + def forward(self, input_feature: torch.Tensor, input_dimensions: Tuple[int, int]) -> torch.Tensor: + height, width = input_dimensions + # `dim` is height * width + batch_size, dim, num_channels = input_feature.shape + + input_feature = input_feature.view(batch_size, height, width, num_channels) + # pad input to be disible by width and height, if needed + input_feature = self.maybe_pad(input_feature, height, width) + # [batch_size, height/2, width/2, num_channels] + input_feature_0 = input_feature[:, 0::2, 0::2, :] + # [batch_size, height/2, width/2, num_channels] + input_feature_1 = input_feature[:, 1::2, 0::2, :] + # [batch_size, height/2, width/2, num_channels] + input_feature_2 = input_feature[:, 0::2, 1::2, :] + # [batch_size, height/2, width/2, num_channels] + input_feature_3 = input_feature[:, 1::2, 1::2, :] + # batch_size height/2 width/2 4*num_channels + input_feature = torch.cat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1) + input_feature = input_feature.view(batch_size, -1, 4 * num_channels) # batch_size height/2*width/2 4*C + + input_feature = self.norm(input_feature) + input_feature = self.reduction(input_feature) + + return input_feature + + +# Copied from transformers.models.swin.modeling_swin.SwinDropPath with Swin->MaskFormerSwin +class MaskFormerSwinDropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: Optional[float] = None) -> None: + super().__init__() + self.drop_prob = drop_prob + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return drop_path(x, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return "p={}".format(self.drop_prob) + + +# Copied from transformers.models.swin.modeling_swin.SwinSelfAttention with Swin->MaskFormerSwin +class MaskFormerSwinSelfAttention(nn.Module): + def __init__(self, config, dim, num_heads, window_size): + super().__init__() + if dim % num_heads != 0: + raise ValueError( + f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})" + ) + + self.num_attention_heads = num_heads + self.attention_head_size = int(dim / num_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.window_size = ( + window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size) + ) + + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads) + ) + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) + coords_flatten = torch.flatten(coords, 1) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += self.window_size[0] - 1 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) + self.register_buffer("relative_position_index", relative_position_index) + + self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias) + self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias) + self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + batch_size, dim, num_channels = hidden_states.shape + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)] + relative_position_bias = relative_position_bias.view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1 + ) + + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() + attention_scores = attention_scores + relative_position_bias.unsqueeze(0) + + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in MaskFormerSwinModel forward() function) + mask_shape = attention_mask.shape[0] + attention_scores = attention_scores.view( + batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim + ) + attention_scores = attention_scores + attention_mask.unsqueeze(1).unsqueeze(0) + attention_scores = attention_scores.view(-1, self.num_attention_heads, dim, dim) + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +# Copied from transformers.models.swin.modeling_swin.SwinSelfOutput with Swin->MaskFormerSwin +class MaskFormerSwinSelfOutput(nn.Module): + def __init__(self, config, dim): + super().__init__() + self.dense = nn.Linear(dim, dim) + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +# Copied from transformers.models.swin.modeling_swin.SwinAttention with Swin->MaskFormerSwin +class MaskFormerSwinAttention(nn.Module): + def __init__(self, config, dim, num_heads, window_size): + super().__init__() + self.self = MaskFormerSwinSelfAttention(config, dim, num_heads, window_size) + self.output = MaskFormerSwinSelfOutput(config, dim) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.swin.modeling_swin.SwinIntermediate with Swin->MaskFormerSwin +class MaskFormerSwinIntermediate(nn.Module): + def __init__(self, config, dim): + super().__init__() + self.dense = nn.Linear(dim, int(config.mlp_ratio * dim)) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.swin.modeling_swin.SwinOutput with Swin->MaskFormerSwin +class MaskFormerSwinOutput(nn.Module): + def __init__(self, config, dim): + super().__init__() + self.dense = nn.Linear(int(config.mlp_ratio * dim), dim) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class MaskFormerSwinLayer(nn.Module): + def __init__(self, config, dim, input_resolution, num_heads, shift_size=0): + super().__init__() + self.shift_size = shift_size + self.window_size = config.window_size + self.input_resolution = input_resolution + self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps) + self.attention = MaskFormerSwinAttention(config, dim, num_heads, self.window_size) + self.drop_path = ( + MaskFormerSwinDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity() + ) + self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps) + self.intermediate = MaskFormerSwinIntermediate(config, dim) + self.output = MaskFormerSwinOutput(config, dim) + + def get_attn_mask(self, input_resolution): + if self.shift_size > 0: + # calculate attention mask for SW-MSA + height, width = input_resolution + img_mask = torch.zeros((1, height, width, 1)) + height_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + width_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + count = 0 + for height_slice in height_slices: + for width_slice in width_slices: + img_mask[:, height_slice, width_slice, :] = count + count += 1 + + mask_windows = window_partition(img_mask, self.window_size) + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + return attn_mask + + def maybe_pad(self, hidden_states, height, width): + pad_left = pad_top = 0 + pad_rigth = (self.window_size - width % self.window_size) % self.window_size + pad_bottom = (self.window_size - height % self.window_size) % self.window_size + pad_values = (0, 0, pad_left, pad_rigth, pad_top, pad_bottom) + hidden_states = nn.functional.pad(hidden_states, pad_values) + return hidden_states, pad_values + + def forward(self, hidden_states, input_dimensions, head_mask=None, output_attentions=False): + height, width = input_dimensions + batch_size, dim, channels = hidden_states.size() + shortcut = hidden_states + + hidden_states = self.layernorm_before(hidden_states) + hidden_states = hidden_states.view(batch_size, height, width, channels) + # pad hidden_states to multiples of window size + hidden_states, pad_values = self.maybe_pad(hidden_states, height, width) + + _, height_pad, width_pad, _ = hidden_states.shape + # cyclic shift + if self.shift_size > 0: + shifted_hidden_states = torch.roll(hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_hidden_states = hidden_states + + # partition windows + hidden_states_windows = window_partition(shifted_hidden_states, self.window_size) + hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels) + attn_mask = self.get_attn_mask((height_pad, width_pad)) + if attn_mask is not None: + attn_mask = attn_mask.to(hidden_states_windows.device) + + self_attention_outputs = self.attention( + hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions + ) + + attention_output = self_attention_outputs[0] + + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels) + shifted_windows = window_reverse( + attention_windows, self.window_size, height_pad, width_pad + ) # B height' width' C + + # reverse cyclic shift + if self.shift_size > 0: + attention_windows = torch.roll(shifted_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + attention_windows = shifted_windows + + was_padded = pad_values[3] > 0 or pad_values[5] > 0 + if was_padded: + attention_windows = attention_windows[:, :height, :width, :].contiguous() + + attention_windows = attention_windows.view(batch_size, height * width, channels) + + hidden_states = shortcut + self.drop_path(attention_windows) + + layer_output = self.layernorm_after(hidden_states) + layer_output = self.intermediate(layer_output) + layer_output = hidden_states + self.output(layer_output) + + outputs = (layer_output,) + outputs + + return outputs + + +class MaskFormerSwinStage(nn.Module): + # Copied from transformers.models.swin.modeling_swin.SwinStage.__init__ with Swin->MaskFormerSwin + def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, downsample): + super().__init__() + self.config = config + self.dim = dim + self.blocks = nn.ModuleList( + [ + MaskFormerSwinLayer( + config=config, + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + shift_size=0 if (i % 2 == 0) else config.window_size // 2, + ) + for i in range(depth) + ] + ) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=nn.LayerNorm) + else: + self.downsample = None + + self.pointing = False + + def forward( + self, hidden_states, input_dimensions, head_mask=None, output_attentions=False, output_hidden_states=False + ): + all_hidden_states = () if output_hidden_states else None + + height, width = input_dimensions + for i, block_module in enumerate(self.blocks): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + block_hidden_states = block_module(hidden_states, input_dimensions, layer_head_mask, output_attentions) + + hidden_states = block_hidden_states[0] + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.downsample is not None: + height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2 + output_dimensions = (height, width, height_downsampled, width_downsampled) + hidden_states = self.downsample(hidden_states, input_dimensions) + else: + output_dimensions = (height, width, height, width) + + return hidden_states, output_dimensions, all_hidden_states + + +class MaskFormerSwinEncoder(nn.Module): + # Copied from transformers.models.swin.modeling_swin.SwinEncoder.__init__ with Swin->MaskFormerSwin + def __init__(self, config, grid_size): + super().__init__() + self.num_layers = len(config.depths) + self.config = config + dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))] + self.layers = nn.ModuleList( + [ + MaskFormerSwinStage( + config=config, + dim=int(config.embed_dim * 2**i_layer), + input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)), + depth=config.depths[i_layer], + num_heads=config.num_heads[i_layer], + drop_path=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])], + downsample=MaskFormerSwinPatchMerging if (i_layer < self.num_layers - 1) else None, + ) + for i_layer in range(self.num_layers) + ] + ) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + input_dimensions, + head_mask=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ): + all_hidden_states = () if output_hidden_states else None + all_input_dimensions = () + all_self_attentions = () if output_attentions else None + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + for i, layer_module in enumerate(self.layers): + layer_head_mask = head_mask[i] if head_mask is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_hidden_states, output_dimensions, layer_all_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), hidden_states, layer_head_mask + ) + else: + layer_hidden_states, output_dimensions, layer_all_hidden_states = layer_module( + hidden_states, + input_dimensions, + layer_head_mask, + output_attentions, + output_hidden_states, + ) + + input_dimensions = (output_dimensions[-2], output_dimensions[-1]) + all_input_dimensions += (input_dimensions,) + if output_hidden_states: + all_hidden_states += (layer_all_hidden_states,) + + hidden_states = layer_hidden_states + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_all_hidden_states[1],) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + + return MaskFormerSwinBaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + hidden_states_spatial_dimensions=all_input_dimensions, + attentions=all_self_attentions, + ) + + +# Copied from transformers.models.swin.modeling_swin.SwinPreTrainedModel with Swin->MaskFormerSwin, swin->model +class MaskFormerSwinPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = MaskFormerSwinConfig + base_model_prefix = "model" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, MaskFormerSwinEncoder): + module.gradient_checkpointing = value + + +class MaskFormerSwinModel(MaskFormerSwinPreTrainedModel): + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.config = config + self.num_layers = len(config.depths) + self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1)) + + self.embeddings = MaskFormerSwinEmbeddings(config) + self.encoder = MaskFormerSwinEncoder(config, self.embeddings.patch_grid) + + self.layernorm = nn.LayerNorm(self.num_features, eps=config.layer_norm_eps) + self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def forward( + self, + pixel_values=None, + head_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, len(self.config.depths)) + + embedding_output, input_dimensions = self.embeddings(pixel_values) + + encoder_outputs = self.encoder( + embedding_output, + input_dimensions, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = encoder_outputs.last_hidden_state if return_dict else encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + + pooled_output = None + if self.pooler is not None: + pooled_output = self.pooler(sequence_output.transpose(1, 2)) + pooled_output = torch.flatten(pooled_output, 1) + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + hidden_states_spatial_dimensions = (input_dimensions,) + encoder_outputs.hidden_states_spatial_dimensions + + return MaskFormerSwinModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + hidden_states_spatial_dimensions=hidden_states_spatial_dimensions, + attentions=encoder_outputs.attentions, + ) + + +class MaskFormerSwinBackbone(MaskFormerSwinPreTrainedModel): + """ + MaskFormerSwin backbone, designed especially for the MaskFormer framework. + + This classes reshapes `hidden_states` from (`batch_size, sequence_length, hidden_size)` to (`batch_size, + num_channels, height, width)`). It also adds additional layernorms after each stage. + + Args: + config (`MaskFormerSwinConfig`): + The configuration used by [`MaskFormerSwinModel`]. + """ + + def __init__(self, config: MaskFormerSwinConfig): + super().__init__(config) + + self.stage_names = config.stage_names + self.model = MaskFormerSwinModel(config) + + self.out_features = config.out_features + if "stem" in self.out_features: + raise ValueError("This backbone does not support 'stem' in the `out_features`.") + + num_features = [int(config.embed_dim * 2**i) for i in range(len(config.depths))] + self.out_feature_channels = {} + for i, stage in enumerate(self.stage_names[1:]): + self.out_feature_channels[stage] = num_features[i] + + self.hidden_states_norms = nn.ModuleList([nn.LayerNorm(num_channels) for num_channels in self.channels]) + + @property + def channels(self): + return [self.out_feature_channels[name] for name in self.out_features] + + def forward( + self, + pixel_values: Tensor, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> BackboneOutput: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + outputs = self.model( + pixel_values, output_hidden_states=True, output_attentions=output_attentions, return_dict=True + ) + + # we skip the stem + hidden_states = outputs.hidden_states[1:] + + feature_maps = () + # we need to reshape the hidden states to their original spatial dimensions + # spatial dimensions contains all the heights and widths of each stage, including after the embeddings + spatial_dimensions: Tuple[Tuple[int, int]] = outputs.hidden_states_spatial_dimensions + for i, (hidden_state, stage, (height, width)) in enumerate( + zip(hidden_states, self.stage_names[1:], spatial_dimensions) + ): + norm = self.hidden_states_norms[i] + # the last element corespond to the layer's last block output but before patch merging + hidden_state_unpolled = hidden_state[-1] + hidden_state_norm = norm(hidden_state_unpolled) + # the pixel decoder (FPN) expects 3D tensors (features) + batch_size, _, hidden_size = hidden_state_norm.shape + # reshape "b (h w) d -> b d h w" + hidden_state_permuted = ( + hidden_state_norm.permute(0, 2, 1).view((batch_size, hidden_size, height, width)).contiguous() + ) + if stage in self.out_features: + feature_maps += (hidden_state_permuted,) + + if not return_dict: + output = (feature_maps,) + if output_hidden_states: + output += (outputs.hidden_states,) + if output_attentions: + output += (outputs.attentions,) + return output + + return BackboneOutput( + feature_maps=feature_maps, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=outputs.attentions, + ) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index a0778af99d..f3d937ac5d 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -3361,6 +3361,13 @@ class MaskFormerPreTrainedModel(metaclass=DummyObject): requires_backends(self, ["torch"]) +class MaskFormerSwinBackbone(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class MBartForCausalLM(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/models/maskformer/test_modeling_maskformer_swin.py b/tests/models/maskformer/test_modeling_maskformer_swin.py new file mode 100644 index 0000000000..95ee8263bc --- /dev/null +++ b/tests/models/maskformer/test_modeling_maskformer_swin.py @@ -0,0 +1,386 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the PyTorch MaskFormer Swin model. """ + +import collections +import inspect +import unittest +from typing import Dict, List, Tuple + +from transformers import MaskFormerSwinConfig +from transformers.testing_utils import require_torch, torch_device +from transformers.utils import is_torch_available + +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor + + +if is_torch_available(): + import torch + from torch import nn + + from transformers import MaskFormerSwinBackbone + from transformers.models.maskformer import MaskFormerSwinModel + + +class MaskFormerSwinModelTester: + def __init__( + self, + parent, + batch_size=13, + image_size=32, + patch_size=2, + num_channels=3, + embed_dim=16, + depths=[1, 2, 1], + num_heads=[2, 2, 4], + window_size=2, + mlp_ratio=2.0, + qkv_bias=True, + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + drop_path_rate=0.1, + hidden_act="gelu", + use_absolute_embeddings=False, + patch_norm=True, + initializer_range=0.02, + layer_norm_eps=1e-5, + is_training=True, + scope=None, + use_labels=True, + type_sequence_label_size=10, + encoder_stride=8, + out_features=["stage1", "stage2", "stage3"], + ): + self.parent = parent + self.batch_size = batch_size + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.embed_dim = embed_dim + self.depths = depths + self.num_heads = num_heads + self.window_size = window_size + self.mlp_ratio = mlp_ratio + self.qkv_bias = qkv_bias + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.drop_path_rate = drop_path_rate + self.hidden_act = hidden_act + self.use_absolute_embeddings = use_absolute_embeddings + self.patch_norm = patch_norm + self.layer_norm_eps = layer_norm_eps + self.initializer_range = initializer_range + self.is_training = is_training + self.scope = scope + self.use_labels = use_labels + self.type_sequence_label_size = type_sequence_label_size + self.encoder_stride = encoder_stride + self.out_features = out_features + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + + labels = None + if self.use_labels: + labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + + config = self.get_config() + + return config, pixel_values, labels + + def get_config(self): + return MaskFormerSwinConfig( + image_size=self.image_size, + patch_size=self.patch_size, + num_channels=self.num_channels, + embed_dim=self.embed_dim, + depths=self.depths, + num_heads=self.num_heads, + window_size=self.window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=self.qkv_bias, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + drop_path_rate=self.drop_path_rate, + hidden_act=self.hidden_act, + use_absolute_embeddings=self.use_absolute_embeddings, + path_norm=self.patch_norm, + layer_norm_eps=self.layer_norm_eps, + initializer_range=self.initializer_range, + encoder_stride=self.encoder_stride, + out_features=self.out_features, + ) + + def create_and_check_model(self, config, pixel_values, labels): + model = MaskFormerSwinModel(config=config) + model.to(torch_device) + model.eval() + result = model(pixel_values) + + expected_seq_len = ((config.image_size // config.patch_size) ** 2) // (4 ** (len(config.depths) - 1)) + expected_dim = int(config.embed_dim * 2 ** (len(config.depths) - 1)) + + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim)) + + def create_and_check_backbone(self, config, pixel_values, labels): + model = MaskFormerSwinBackbone(config=config) + model.to(torch_device) + model.eval() + result = model(pixel_values) + + # verify feature maps + self.parent.assertEqual(len(result.feature_maps), len(config.out_features)) + self.parent.assertListEqual(list(result.feature_maps[0].shape), [13, 16, 16, 16]) + + # verify channels + self.parent.assertEqual(len(model.channels), len(config.out_features)) + self.parent.assertListEqual(model.channels, [16, 32, 64]) + + # verify ValueError + with self.parent.assertRaises(ValueError): + config.out_features = ["stem"] + model = MaskFormerSwinBackbone(config=config) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values, labels = config_and_inputs + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + +@require_torch +class MaskFormerSwinModelTest(ModelTesterMixin, unittest.TestCase): + + all_model_classes = ( + ( + MaskFormerSwinModel, + MaskFormerSwinBackbone, + ) + if is_torch_available() + else () + ) + fx_compatible = False + test_torchscript = False + test_pruning = False + test_resize_embeddings = False + test_head_masking = False + + def setUp(self): + self.model_tester = MaskFormerSwinModelTester(self) + self.config_tester = ConfigTester(self, config_class=MaskFormerSwinConfig, embed_dim=37) + + def test_config(self): + self.create_and_test_config_common_properties() + self.config_tester.create_and_test_config_to_json_string() + self.config_tester.create_and_test_config_to_json_file() + self.config_tester.create_and_test_config_from_and_save_pretrained() + self.config_tester.create_and_test_config_with_num_labels() + self.config_tester.check_config_can_be_init_without_params() + self.config_tester.check_config_arguments_init() + + def create_and_test_config_common_properties(self): + return + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_backbone(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_backbone(*config_and_inputs) + + @unittest.skip("Swin does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + @unittest.skip("Swin does not support feedforward chunking") + def test_feed_forward_chunking(self): + pass + + def test_model_common_attributes(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + self.assertIsInstance(model.get_input_embeddings(), (nn.Module)) + x = model.get_output_embeddings() + self.assertTrue(x is None or isinstance(x, nn.Linear)) + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["pixel_values"] + self.assertListEqual(arg_names[:1], expected_arg_names) + + @unittest.skip(reason="MaskFormerSwin is only used as backbone and doesn't support output_attentions") + def test_attention_outputs(self): + pass + + @unittest.skip(reason="MaskFormerSwin is only used as an internal backbone") + def test_save_load_fast_init_to_base(self): + pass + + def check_hidden_states_output(self, inputs_dict, config, model_class, image_size): + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + hidden_states = outputs.hidden_states + + expected_num_layers = getattr( + self.model_tester, "expected_num_hidden_layers", len(self.model_tester.depths) + 1 + ) + self.assertEqual(len(hidden_states), expected_num_layers) + + # Swin has a different seq_length + patch_size = ( + config.patch_size + if isinstance(config.patch_size, collections.abc.Iterable) + else (config.patch_size, config.patch_size) + ) + + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + + self.assertListEqual( + list(hidden_states[0].shape[-2:]), + [num_patches, self.model_tester.embed_dim], + ) + + def test_hidden_states_output(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + image_size = ( + self.model_tester.image_size + if isinstance(self.model_tester.image_size, collections.abc.Iterable) + else (self.model_tester.image_size, self.model_tester.image_size) + ) + + for model_class in self.all_model_classes: + inputs_dict["output_hidden_states"] = True + self.check_hidden_states_output(inputs_dict, config, model_class, image_size) + + # check that output_hidden_states also work using config + del inputs_dict["output_hidden_states"] + config.output_hidden_states = True + + self.check_hidden_states_output(inputs_dict, config, model_class, image_size) + + def test_hidden_states_output_with_padding(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.patch_size = 3 + + image_size = ( + self.model_tester.image_size + if isinstance(self.model_tester.image_size, collections.abc.Iterable) + else (self.model_tester.image_size, self.model_tester.image_size) + ) + patch_size = ( + config.patch_size + if isinstance(config.patch_size, collections.abc.Iterable) + else (config.patch_size, config.patch_size) + ) + + padded_height = image_size[0] + patch_size[0] - (image_size[0] % patch_size[0]) + padded_width = image_size[1] + patch_size[1] - (image_size[1] % patch_size[1]) + + for model_class in self.all_model_classes: + inputs_dict["output_hidden_states"] = True + self.check_hidden_states_output(inputs_dict, config, model_class, (padded_height, padded_width)) + + # check that output_hidden_states also work using config + del inputs_dict["output_hidden_states"] + config.output_hidden_states = True + self.check_hidden_states_output(inputs_dict, config, model_class, (padded_height, padded_width)) + + @unittest.skip(reason="MaskFormerSwin doesn't have pretrained checkpoints") + def test_model_from_pretrained(self): + pass + + @unittest.skip(reason="This will be fixed once MaskFormerSwin is replaced by native Swin") + def test_initialization(self): + pass + + @unittest.skip(reason="This will be fixed once MaskFormerSwin is replaced by native Swin") + def test_gradient_checkpointing_backward_compatibility(self): + pass + + def test_model_outputs_equivalence(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + def set_nan_tensor_to_zero(t): + t[t != t] = 0 + return t + + def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}): + with torch.no_grad(): + tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs) + dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple() + + def recursive_check(tuple_object, dict_object): + if isinstance(tuple_object, (List, Tuple)): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif isinstance(tuple_object, Dict): + for tuple_iterable_value, dict_iterable_value in zip( + tuple_object.values(), dict_object.values() + ): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif tuple_object is None: + return + else: + self.assertTrue( + torch.allclose( + set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5 + ), + msg=( + "Tuple and dict output are not equal. Difference:" + f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:" + f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has" + f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}." + ), + ) + + recursive_check(tuple_output, dict_output) + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + + tuple_inputs = self._prepare_for_class(inputs_dict, model_class) + dict_inputs = self._prepare_for_class(inputs_dict, model_class) + check_equivalence(model, tuple_inputs, dict_inputs) + + tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + check_equivalence(model, tuple_inputs, dict_inputs) + + tuple_inputs = self._prepare_for_class(inputs_dict, model_class) + dict_inputs = self._prepare_for_class(inputs_dict, model_class) + check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True}) + + tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True}) diff --git a/utils/check_copies.py b/utils/check_copies.py index f5d1b1c416..728a650942 100644 --- a/utils/check_copies.py +++ b/utils/check_copies.py @@ -492,8 +492,9 @@ SPECIAL_MODEL_NAMES = { "Data2VecAudio": "Data2Vec", "Data2VecText": "Data2Vec", "Data2VecVision": "Data2Vec", - "DonutSwin": "Donut", + "DonutSwin": "Swin Transformer", "Marian": "MarianMT", + "MaskFormerSwin": "Swin Transformer", "OpenAI GPT-2": "GPT-2", "OpenAI GPT": "GPT", "Perceiver": "Perceiver IO", diff --git a/utils/check_repo.py b/utils/check_repo.py index a8ad36b385..494a703f7e 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -41,6 +41,8 @@ PRIVATE_MODELS = [ "T5Stack", "SwitchTransformersStack", "TFDPRSpanPredictor", + "MaskFormerSwinModel", + "MaskFormerSwinPreTrainedModel", ] # Update this list for models that are not tested with a comment explaining the reason it should not be. @@ -668,8 +670,11 @@ SHOULD_HAVE_THEIR_OWN_PAGE = [ "PyTorchBenchmarkArguments", "TensorFlowBenchmark", "TensorFlowBenchmarkArguments", + "MaskFormerSwinBackbone", "ResNetBackbone", "AutoBackbone", + "MaskFormerSwinConfig", + "MaskFormerSwinModel", ]