Improve vision models (#17731)
* Improve vision models * Add a lot of improvements * Remove to_2tuple from swin tests * Fix TF Swin * Fix more tests * Fix copies * Improve more models * Fix ViTMAE test * Add channel check for TF models * Add proper channel check for TF models * Apply suggestion from code review * Apply suggestions from code review * Add channel check for Flax models, apply suggestion * Fix bug * Add tests for greyscale images * Add test for interpolation of pos encodings Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
@@ -91,17 +91,7 @@ class BeitModelOutputWithPooling(BaseModelOutputWithPooling):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
# Inspired by
|
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
|
||||||
# https://github.com/rwightman/pytorch-image-models/blob/b9bd960a032c75ca6b808ddeed76bee5f3ed4972/timm/models/layers/helpers.py
|
|
||||||
# From PyTorch internals
|
|
||||||
def to_2tuple(x):
|
|
||||||
if isinstance(x, collections.abc.Iterable):
|
|
||||||
return x
|
|
||||||
return (x, x)
|
|
||||||
|
|
||||||
|
|
||||||
# Based on https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py
|
|
||||||
def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
|
|
||||||
"""
|
"""
|
||||||
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||||
|
|
||||||
@@ -112,16 +102,16 @@ def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -
|
|||||||
argument.
|
argument.
|
||||||
"""
|
"""
|
||||||
if drop_prob == 0.0 or not training:
|
if drop_prob == 0.0 or not training:
|
||||||
return x
|
return input
|
||||||
keep_prob = 1 - drop_prob
|
keep_prob = 1 - drop_prob
|
||||||
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
||||||
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
|
||||||
random_tensor.floor_() # binarize
|
random_tensor.floor_() # binarize
|
||||||
output = x.div(keep_prob) * random_tensor
|
output = input.div(keep_prob) * random_tensor
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
class DropPath(nn.Module):
|
class BeitDropPath(nn.Module):
|
||||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
||||||
|
|
||||||
def __init__(self, drop_prob: Optional[float] = None) -> None:
|
def __init__(self, drop_prob: Optional[float] = None) -> None:
|
||||||
@@ -151,12 +141,7 @@ class BeitEmbeddings(nn.Module):
|
|||||||
self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
|
self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
|
||||||
else:
|
else:
|
||||||
self.mask_token = None
|
self.mask_token = None
|
||||||
self.patch_embeddings = PatchEmbeddings(
|
self.patch_embeddings = BeitPatchEmbeddings(config)
|
||||||
image_size=config.image_size,
|
|
||||||
patch_size=config.patch_size,
|
|
||||||
num_channels=config.num_channels,
|
|
||||||
embed_dim=config.hidden_size,
|
|
||||||
)
|
|
||||||
num_patches = self.patch_embeddings.num_patches
|
num_patches = self.patch_embeddings.num_patches
|
||||||
if config.use_absolute_position_embeddings:
|
if config.use_absolute_position_embeddings:
|
||||||
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
|
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
|
||||||
@@ -184,38 +169,43 @@ class BeitEmbeddings(nn.Module):
|
|||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
# Based on timm implementation, which can be found here:
|
class BeitPatchEmbeddings(nn.Module):
|
||||||
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
|
||||||
class PatchEmbeddings(nn.Module):
|
|
||||||
"""
|
"""
|
||||||
Image to Patch Embedding.
|
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
|
||||||
|
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
|
||||||
|
Transformer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, config):
|
||||||
self, image_size: int = 224, patch_size: int = 16, num_channels: int = 3, embed_dim: int = 768
|
|
||||||
) -> None:
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
image_size = to_2tuple(image_size)
|
image_size, patch_size = config.image_size, config.patch_size
|
||||||
patch_size = to_2tuple(patch_size)
|
num_channels, hidden_size = config.num_channels, config.hidden_size
|
||||||
|
|
||||||
|
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
|
||||||
|
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
|
||||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||||
patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
|
patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
|
||||||
self.image_size = image_size
|
self.image_size = image_size
|
||||||
self.patch_size = patch_size
|
self.patch_size = patch_size
|
||||||
|
self.num_channels = num_channels
|
||||||
self.num_patches = num_patches
|
self.num_patches = num_patches
|
||||||
self.patch_shape = patch_shape
|
self.patch_shape = patch_shape
|
||||||
|
|
||||||
self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
|
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
|
||||||
|
|
||||||
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
||||||
batch_size, num_channels, height, width = pixel_values.shape
|
batch_size, num_channels, height, width = pixel_values.shape
|
||||||
# FIXME look at relaxing size constraints
|
if num_channels != self.num_channels:
|
||||||
|
raise ValueError(
|
||||||
|
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
||||||
|
)
|
||||||
if height != self.image_size[0] or width != self.image_size[1]:
|
if height != self.image_size[0] or width != self.image_size[1]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
|
f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
|
||||||
)
|
)
|
||||||
x = self.projection(pixel_values).flatten(2).transpose(1, 2)
|
embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
|
||||||
|
|
||||||
return x
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
class BeitSelfAttention(nn.Module):
|
class BeitSelfAttention(nn.Module):
|
||||||
@@ -393,7 +383,7 @@ class BeitLayer(nn.Module):
|
|||||||
self.intermediate = BeitIntermediate(config)
|
self.intermediate = BeitIntermediate(config)
|
||||||
self.output = BeitOutput(config)
|
self.output = BeitOutput(config)
|
||||||
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
|
self.drop_path = BeitDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
|
||||||
self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
|
||||||
init_values = config.layer_scale_init_value
|
init_values = config.layer_scale_init_value
|
||||||
|
|||||||
@@ -171,6 +171,7 @@ class FlaxBeitPatchEmbeddings(nn.Module):
|
|||||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
|
self.num_channels = self.config.num_channels
|
||||||
image_size = self.config.image_size
|
image_size = self.config.image_size
|
||||||
patch_size = self.config.patch_size
|
patch_size = self.config.patch_size
|
||||||
num_patches = (image_size // patch_size) * (image_size // patch_size)
|
num_patches = (image_size // patch_size) * (image_size // patch_size)
|
||||||
@@ -187,6 +188,11 @@ class FlaxBeitPatchEmbeddings(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, pixel_values):
|
def __call__(self, pixel_values):
|
||||||
|
num_channels = pixel_values.shape[-1]
|
||||||
|
if num_channels != self.num_channels:
|
||||||
|
raise ValueError(
|
||||||
|
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
||||||
|
)
|
||||||
embeddings = self.projection(pixel_values)
|
embeddings = self.projection(pixel_values)
|
||||||
batch_size, _, _, channels = embeddings.shape
|
batch_size, _, _, channels = embeddings.shape
|
||||||
return jnp.reshape(embeddings, (batch_size, -1, channels))
|
return jnp.reshape(embeddings, (batch_size, -1, channels))
|
||||||
@@ -603,7 +609,7 @@ class FlaxBeitPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
):
|
):
|
||||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||||
if input_shape is None:
|
if input_shape is None:
|
||||||
input_shape = (1, config.image_size, config.image_size, 3)
|
input_shape = (1, config.image_size, config.image_size, config.num_channels)
|
||||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||||
|
|
||||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||||
|
|||||||
@@ -53,36 +53,41 @@ CONVNEXT_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
# Stochastic depth implementation
|
# Copied from transformers.models.beit.modeling_beit.drop_path
|
||||||
# Taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
|
def drop_path(input, drop_prob: float = 0.0, training: bool = False):
|
||||||
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
|
|
||||||
"""
|
"""
|
||||||
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). This is the same as the
|
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||||
DropConnect impl I created for EfficientNet, etc networks, however, the original name is misleading as 'Drop
|
|
||||||
Connect' is a different form of dropout in a separate paper... See discussion:
|
Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
|
||||||
https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the layer and
|
however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
||||||
argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the argument.
|
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
|
||||||
|
layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
|
||||||
|
argument.
|
||||||
"""
|
"""
|
||||||
if drop_prob == 0.0 or not training:
|
if drop_prob == 0.0 or not training:
|
||||||
return x
|
return input
|
||||||
keep_prob = 1 - drop_prob
|
keep_prob = 1 - drop_prob
|
||||||
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
||||||
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
|
||||||
random_tensor.floor_() # binarize
|
random_tensor.floor_() # binarize
|
||||||
output = x.div(keep_prob) * random_tensor
|
output = input.div(keep_prob) * random_tensor
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->ConvNext
|
||||||
class ConvNextDropPath(nn.Module):
|
class ConvNextDropPath(nn.Module):
|
||||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
||||||
|
|
||||||
def __init__(self, drop_prob=None):
|
def __init__(self, drop_prob: Optional[float] = None) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.drop_prob = drop_prob
|
self.drop_prob = drop_prob
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
return drop_path(x, self.drop_prob, self.training)
|
return drop_path(x, self.drop_prob, self.training)
|
||||||
|
|
||||||
|
def extra_repr(self) -> str:
|
||||||
|
return "p={}".format(self.drop_prob)
|
||||||
|
|
||||||
|
|
||||||
class ConvNextLayerNorm(nn.Module):
|
class ConvNextLayerNorm(nn.Module):
|
||||||
r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
|
r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
|
||||||
@@ -122,8 +127,14 @@ class ConvNextEmbeddings(nn.Module):
|
|||||||
config.num_channels, config.hidden_sizes[0], kernel_size=config.patch_size, stride=config.patch_size
|
config.num_channels, config.hidden_sizes[0], kernel_size=config.patch_size, stride=config.patch_size
|
||||||
)
|
)
|
||||||
self.layernorm = ConvNextLayerNorm(config.hidden_sizes[0], eps=1e-6, data_format="channels_first")
|
self.layernorm = ConvNextLayerNorm(config.hidden_sizes[0], eps=1e-6, data_format="channels_first")
|
||||||
|
self.num_channels = config.num_channels
|
||||||
|
|
||||||
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
||||||
|
num_channels = pixel_values.shape[1]
|
||||||
|
if num_channels != self.num_channels:
|
||||||
|
raise ValueError(
|
||||||
|
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
||||||
|
)
|
||||||
embeddings = self.patch_embeddings(pixel_values)
|
embeddings = self.patch_embeddings(pixel_values)
|
||||||
embeddings = self.layernorm(embeddings)
|
embeddings = self.layernorm(embeddings)
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|||||||
@@ -20,6 +20,8 @@ from typing import Dict, Optional, Tuple, Union
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from transformers import shape_list
|
||||||
|
|
||||||
from ...activations_tf import get_tf_activation
|
from ...activations_tf import get_tf_activation
|
||||||
from ...modeling_tf_outputs import TFBaseModelOutput, TFBaseModelOutputWithPooling, TFSequenceClassifierOutput
|
from ...modeling_tf_outputs import TFBaseModelOutput, TFBaseModelOutputWithPooling, TFSequenceClassifierOutput
|
||||||
from ...modeling_tf_utils import (
|
from ...modeling_tf_utils import (
|
||||||
@@ -77,11 +79,18 @@ class TFConvNextEmbeddings(tf.keras.layers.Layer):
|
|||||||
bias_initializer="zeros",
|
bias_initializer="zeros",
|
||||||
)
|
)
|
||||||
self.layernorm = tf.keras.layers.LayerNormalization(epsilon=1e-6, name="layernorm")
|
self.layernorm = tf.keras.layers.LayerNormalization(epsilon=1e-6, name="layernorm")
|
||||||
|
self.num_channels = config.num_channels
|
||||||
|
|
||||||
def call(self, pixel_values):
|
def call(self, pixel_values):
|
||||||
if isinstance(pixel_values, dict):
|
if isinstance(pixel_values, dict):
|
||||||
pixel_values = pixel_values["pixel_values"]
|
pixel_values = pixel_values["pixel_values"]
|
||||||
|
|
||||||
|
num_channels = shape_list(pixel_values)[1]
|
||||||
|
if tf.executing_eagerly() and num_channels != self.num_channels:
|
||||||
|
raise ValueError(
|
||||||
|
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
||||||
|
)
|
||||||
|
|
||||||
# When running on CPU, `tf.keras.layers.Conv2D` doesn't support `NCHW` format.
|
# When running on CPU, `tf.keras.layers.Conv2D` doesn't support `NCHW` format.
|
||||||
# So change the input format from `NCHW` to `NHWC`.
|
# So change the input format from `NCHW` to `NHWC`.
|
||||||
# shape = (batch_size, in_height, in_width, in_channels=num_channels)
|
# shape = (batch_size, in_height, in_width, in_channels=num_channels)
|
||||||
|
|||||||
@@ -78,36 +78,41 @@ class BaseModelOutputWithCLSToken(ModelOutput):
|
|||||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.convnext.modeling_convnext.drop_path
|
# Copied from transformers.models.beit.modeling_beit.drop_path
|
||||||
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
|
def drop_path(input, drop_prob: float = 0.0, training: bool = False):
|
||||||
"""
|
"""
|
||||||
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). This is the same as the
|
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||||
DropConnect impl I created for EfficientNet, etc networks, however, the original name is misleading as 'Drop
|
|
||||||
Connect' is a different form of dropout in a separate paper... See discussion:
|
Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
|
||||||
https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the layer and
|
however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
||||||
argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the argument.
|
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
|
||||||
|
layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
|
||||||
|
argument.
|
||||||
"""
|
"""
|
||||||
if drop_prob == 0.0 or not training:
|
if drop_prob == 0.0 or not training:
|
||||||
return x
|
return input
|
||||||
keep_prob = 1 - drop_prob
|
keep_prob = 1 - drop_prob
|
||||||
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
||||||
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
|
||||||
random_tensor.floor_() # binarize
|
random_tensor.floor_() # binarize
|
||||||
output = x.div(keep_prob) * random_tensor
|
output = input.div(keep_prob) * random_tensor
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.convnext.modeling_convnext.ConvNextDropPath
|
# Copied from transformers.models.beit.modeling_beit.BeitDropPath
|
||||||
class CvtDropPath(nn.Module):
|
class CvtDropPath(nn.Module):
|
||||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
||||||
|
|
||||||
def __init__(self, drop_prob=None):
|
def __init__(self, drop_prob: Optional[float] = None) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.drop_prob = drop_prob
|
self.drop_prob = drop_prob
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
return drop_path(x, self.drop_prob, self.training)
|
return drop_path(x, self.drop_prob, self.training)
|
||||||
|
|
||||||
|
def extra_repr(self) -> str:
|
||||||
|
return "p={}".format(self.drop_prob)
|
||||||
|
|
||||||
|
|
||||||
class CvtEmbeddings(nn.Module):
|
class CvtEmbeddings(nn.Module):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -91,18 +91,8 @@ class Data2VecVisionModelOutputWithPooling(BaseModelOutputWithPooling):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
# Inspired by
|
|
||||||
# https://github.com/rwightman/pytorch-image-models/blob/b9bd960a032c75ca6b808ddeed76bee5f3ed4972/timm/models/layers/helpers.py
|
|
||||||
# From PyTorch internals
|
|
||||||
def to_2tuple(x):
|
|
||||||
if isinstance(x, collections.abc.Iterable):
|
|
||||||
return x
|
|
||||||
return (x, x)
|
|
||||||
|
|
||||||
|
|
||||||
# Based on https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py
|
|
||||||
# Copied from transformers.models.beit.modeling_beit.drop_path
|
# Copied from transformers.models.beit.modeling_beit.drop_path
|
||||||
def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
|
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||||
|
|
||||||
@@ -113,17 +103,17 @@ def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -
|
|||||||
argument.
|
argument.
|
||||||
"""
|
"""
|
||||||
if drop_prob == 0.0 or not training:
|
if drop_prob == 0.0 or not training:
|
||||||
return x
|
return input
|
||||||
keep_prob = 1 - drop_prob
|
keep_prob = 1 - drop_prob
|
||||||
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
||||||
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
|
||||||
random_tensor.floor_() # binarize
|
random_tensor.floor_() # binarize
|
||||||
output = x.div(keep_prob) * random_tensor
|
output = input.div(keep_prob) * random_tensor
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.beit.modeling_beit.DropPath
|
# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->Data2VecVision
|
||||||
class DropPath(nn.Module):
|
class Data2VecVisionDropPath(nn.Module):
|
||||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
||||||
|
|
||||||
def __init__(self, drop_prob: Optional[float] = None) -> None:
|
def __init__(self, drop_prob: Optional[float] = None) -> None:
|
||||||
@@ -137,8 +127,6 @@ class DropPath(nn.Module):
|
|||||||
return "p={}".format(self.drop_prob)
|
return "p={}".format(self.drop_prob)
|
||||||
|
|
||||||
|
|
||||||
# Based on timm implementation, which can be found here:
|
|
||||||
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
|
||||||
# Copied from transformers.models.beit.modeling_beit.BeitEmbeddings with Beit->Data2VecVision
|
# Copied from transformers.models.beit.modeling_beit.BeitEmbeddings with Beit->Data2VecVision
|
||||||
class Data2VecVisionEmbeddings(nn.Module):
|
class Data2VecVisionEmbeddings(nn.Module):
|
||||||
"""
|
"""
|
||||||
@@ -154,12 +142,7 @@ class Data2VecVisionEmbeddings(nn.Module):
|
|||||||
self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
|
self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
|
||||||
else:
|
else:
|
||||||
self.mask_token = None
|
self.mask_token = None
|
||||||
self.patch_embeddings = PatchEmbeddings(
|
self.patch_embeddings = Data2VecVisionPatchEmbeddings(config)
|
||||||
image_size=config.image_size,
|
|
||||||
patch_size=config.patch_size,
|
|
||||||
num_channels=config.num_channels,
|
|
||||||
embed_dim=config.hidden_size,
|
|
||||||
)
|
|
||||||
num_patches = self.patch_embeddings.num_patches
|
num_patches = self.patch_embeddings.num_patches
|
||||||
if config.use_absolute_position_embeddings:
|
if config.use_absolute_position_embeddings:
|
||||||
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
|
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
|
||||||
@@ -187,39 +170,44 @@ class Data2VecVisionEmbeddings(nn.Module):
|
|||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
# Based on timm implementation, which can be found here:
|
# Copied from transformers.models.beit.modeling_beit.BeitPatchEmbeddings with Beit->Data2VecVision
|
||||||
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
class Data2VecVisionPatchEmbeddings(nn.Module):
|
||||||
# Copied from transformers.models.beit.modeling_beit.PatchEmbeddings
|
|
||||||
class PatchEmbeddings(nn.Module):
|
|
||||||
"""
|
"""
|
||||||
Image to Patch Embedding.
|
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
|
||||||
|
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
|
||||||
|
Transformer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, config):
|
||||||
self, image_size: int = 224, patch_size: int = 16, num_channels: int = 3, embed_dim: int = 768
|
|
||||||
) -> None:
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
image_size = to_2tuple(image_size)
|
image_size, patch_size = config.image_size, config.patch_size
|
||||||
patch_size = to_2tuple(patch_size)
|
num_channels, hidden_size = config.num_channels, config.hidden_size
|
||||||
|
|
||||||
|
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
|
||||||
|
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
|
||||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||||
patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
|
patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
|
||||||
self.image_size = image_size
|
self.image_size = image_size
|
||||||
self.patch_size = patch_size
|
self.patch_size = patch_size
|
||||||
|
self.num_channels = num_channels
|
||||||
self.num_patches = num_patches
|
self.num_patches = num_patches
|
||||||
self.patch_shape = patch_shape
|
self.patch_shape = patch_shape
|
||||||
|
|
||||||
self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
|
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
|
||||||
|
|
||||||
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
||||||
batch_size, num_channels, height, width = pixel_values.shape
|
batch_size, num_channels, height, width = pixel_values.shape
|
||||||
# FIXME look at relaxing size constraints
|
if num_channels != self.num_channels:
|
||||||
|
raise ValueError(
|
||||||
|
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
||||||
|
)
|
||||||
if height != self.image_size[0] or width != self.image_size[1]:
|
if height != self.image_size[0] or width != self.image_size[1]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
|
f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
|
||||||
)
|
)
|
||||||
x = self.projection(pixel_values).flatten(2).transpose(1, 2)
|
embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
|
||||||
|
|
||||||
return x
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.beit.modeling_beit.BeitSelfAttention with Beit->Data2VecVision
|
# Copied from transformers.models.beit.modeling_beit.BeitSelfAttention with Beit->Data2VecVision
|
||||||
@@ -405,7 +393,7 @@ class Data2VecVisionLayer(nn.Module):
|
|||||||
self.intermediate = Data2VecVisionIntermediate(config)
|
self.intermediate = Data2VecVisionIntermediate(config)
|
||||||
self.output = Data2VecVisionOutput(config)
|
self.output = Data2VecVisionOutput(config)
|
||||||
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
|
self.drop_path = Data2VecVisionDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
|
||||||
self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
|
||||||
init_values = config.layer_scale_init_value
|
init_values = config.layer_scale_init_value
|
||||||
|
|||||||
@@ -100,7 +100,7 @@ class TFData2VecVisionModelOutputWithPooling(TFBaseModelOutputWithPooling):
|
|||||||
attentions: Optional[Tuple[tf.Tensor]] = None
|
attentions: Optional[Tuple[tf.Tensor]] = None
|
||||||
|
|
||||||
|
|
||||||
class TFDropPath(tf.keras.layers.Layer):
|
class TFData2VecVisionDropPath(tf.keras.layers.Layer):
|
||||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||||
References:
|
References:
|
||||||
(1) github.com:rwightman/pytorch-image-models
|
(1) github.com:rwightman/pytorch-image-models
|
||||||
@@ -120,8 +120,6 @@ class TFDropPath(tf.keras.layers.Layer):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
# Based on timm implementation, which can be found here:
|
|
||||||
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
|
||||||
class TFData2VecVisionEmbeddings(tf.keras.layers.Layer):
|
class TFData2VecVisionEmbeddings(tf.keras.layers.Layer):
|
||||||
"""
|
"""
|
||||||
Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
|
Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
|
||||||
@@ -132,9 +130,7 @@ class TFData2VecVisionEmbeddings(tf.keras.layers.Layer):
|
|||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
self.patch_embeddings = TFPatchEmbeddings(
|
self.patch_embeddings = TFData2VecVisionPatchEmbeddings(config, name="patch_embeddings")
|
||||||
config=config, image_size=config.image_size, patch_size=config.patch_size, name="patch_embeddings"
|
|
||||||
)
|
|
||||||
self.num_patches = self.patch_embeddings.num_patches
|
self.num_patches = self.patch_embeddings.num_patches
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
@@ -192,40 +188,32 @@ class TFData2VecVisionEmbeddings(tf.keras.layers.Layer):
|
|||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
# Based on timm implementation, which can be found here:
|
class TFData2VecVisionPatchEmbeddings(tf.keras.layers.Layer):
|
||||||
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
|
||||||
class TFPatchEmbeddings(tf.keras.layers.Layer):
|
|
||||||
"""
|
"""
|
||||||
Image to Patch Embedding.
|
Image to Patch Embedding.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config: Data2VecVisionConfig, image_size: int = 224, patch_size: int = 16, **kwargs):
|
def __init__(self, config: Data2VecVisionConfig, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
image_size = (
|
image_size, patch_size = config.image_size, config.patch_size
|
||||||
config.image_size
|
num_channels, hidden_size = config.num_channels, config.hidden_size
|
||||||
if isinstance(config.image_size, collections.abc.Iterable)
|
|
||||||
else (config.image_size, config.image_size)
|
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
|
||||||
)
|
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
|
||||||
patch_size = (
|
|
||||||
config.patch_size
|
|
||||||
if isinstance(config.patch_size, collections.abc.Iterable)
|
|
||||||
else (config.patch_size, config.patch_size)
|
|
||||||
)
|
|
||||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||||
patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
|
patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
|
||||||
self.image_size = image_size
|
self.image_size = image_size
|
||||||
self.patch_size = patch_size
|
self.patch_size = patch_size
|
||||||
self.num_patches = num_patches
|
self.num_patches = num_patches
|
||||||
self.patch_shape = patch_shape
|
self.patch_shape = patch_shape
|
||||||
self.num_channels = config.num_channels
|
self.num_channels = num_channels
|
||||||
self.embed_dim = config.hidden_size
|
|
||||||
|
|
||||||
self.projection = tf.keras.layers.Conv2D(
|
self.projection = tf.keras.layers.Conv2D(
|
||||||
filters=self.embed_dim,
|
filters=hidden_size,
|
||||||
kernel_size=self.patch_size,
|
kernel_size=patch_size,
|
||||||
strides=self.patch_size,
|
strides=patch_size,
|
||||||
padding="valid",
|
padding="valid",
|
||||||
data_format="channels_last",
|
data_format="channels_last",
|
||||||
kernel_initializer="glorot_uniform", # following torch.nn.Linear
|
kernel_initializer="glorot_uniform", # following torch.nn.Linear
|
||||||
@@ -235,7 +223,12 @@ class TFPatchEmbeddings(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor:
|
def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor:
|
||||||
batch_size, num_channels, height, width = shape_list(pixel_values)
|
batch_size, num_channels, height, width = shape_list(pixel_values)
|
||||||
if getattr(height, "numpy", None) and getattr(width, "numpy", None):
|
if tf.executing_eagerly():
|
||||||
|
if num_channels != self.num_channels:
|
||||||
|
raise ValueError(
|
||||||
|
"Make sure that the channel dimension of the pixel values match with the one set in the"
|
||||||
|
" configuration."
|
||||||
|
)
|
||||||
if height != self.image_size[0] or width != self.image_size[1]:
|
if height != self.image_size[0] or width != self.image_size[1]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Input image size ({height}*{width}) doesn't match model"
|
f"Input image size ({height}*{width}) doesn't match model"
|
||||||
@@ -465,7 +458,7 @@ class TFData2VecVisionLayer(tf.keras.layers.Layer):
|
|||||||
# Using `layers.Activation` instead of `tf.identity` to better control `training`
|
# Using `layers.Activation` instead of `tf.identity` to better control `training`
|
||||||
# behaviour.
|
# behaviour.
|
||||||
self.drop_path = (
|
self.drop_path = (
|
||||||
TFDropPath(drop_path_rate, name="drop_path")
|
TFData2VecVisionDropPath(drop_path_rate, name="drop_path")
|
||||||
if drop_path_rate > 0.0
|
if drop_path_rate > 0.0
|
||||||
else tf.keras.layers.Activation("linear", name="drop_path")
|
else tf.keras.layers.Activation("linear", name="drop_path")
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -61,21 +61,9 @@ DEIT_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.vit.modeling_vit.to_2tuple
|
|
||||||
def to_2tuple(x):
|
|
||||||
if isinstance(x, collections.abc.Iterable):
|
|
||||||
return x
|
|
||||||
return (x, x)
|
|
||||||
|
|
||||||
|
|
||||||
# Based on timm implementation, which can be found here:
|
|
||||||
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
|
||||||
|
|
||||||
|
|
||||||
class DeiTEmbeddings(nn.Module):
|
class DeiTEmbeddings(nn.Module):
|
||||||
"""
|
"""
|
||||||
Construct the CLS token, distillation token, position and patch embeddings. Optionally, also the mask token.
|
Construct the CLS token, distillation token, position and patch embeddings. Optionally, also the mask token.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config: DeiTConfig, use_mask_token: bool = False) -> None:
|
def __init__(self, config: DeiTConfig, use_mask_token: bool = False) -> None:
|
||||||
@@ -84,22 +72,17 @@ class DeiTEmbeddings(nn.Module):
|
|||||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
|
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
|
||||||
self.distillation_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
|
self.distillation_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
|
||||||
self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
|
self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
|
||||||
self.patch_embeddings = PatchEmbeddings(
|
self.patch_embeddings = DeiTPatchEmbeddings(config)
|
||||||
image_size=config.image_size,
|
|
||||||
patch_size=config.patch_size,
|
|
||||||
num_channels=config.num_channels,
|
|
||||||
embed_dim=config.hidden_size,
|
|
||||||
)
|
|
||||||
num_patches = self.patch_embeddings.num_patches
|
num_patches = self.patch_embeddings.num_patches
|
||||||
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 2, config.hidden_size))
|
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 2, config.hidden_size))
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.BoolTensor] = None) -> torch.Tensor:
|
def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.BoolTensor] = None) -> torch.Tensor:
|
||||||
embeddings = self.patch_embeddings(pixel_values)
|
embeddings = self.patch_embeddings(pixel_values)
|
||||||
batch_size, seq_len, _ = embeddings.size()
|
batch_size, seq_length, _ = embeddings.size()
|
||||||
|
|
||||||
if bool_masked_pos is not None:
|
if bool_masked_pos is not None:
|
||||||
mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
|
mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
|
||||||
# replace the masked visual tokens by mask_tokens
|
# replace the masked visual tokens by mask_tokens
|
||||||
mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
|
mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
|
||||||
embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
|
embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
|
||||||
@@ -112,32 +95,34 @@ class DeiTEmbeddings(nn.Module):
|
|||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
class PatchEmbeddings(nn.Module):
|
class DeiTPatchEmbeddings(nn.Module):
|
||||||
"""
|
"""
|
||||||
Image to Patch Embedding.
|
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
|
||||||
|
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
|
||||||
|
Transformer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, config):
|
||||||
self,
|
|
||||||
image_size: int = 224,
|
|
||||||
patch_size: Union[int, Tuple[int, int]] = 16,
|
|
||||||
num_channels: int = 3,
|
|
||||||
embed_dim: int = 768,
|
|
||||||
) -> None:
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
image_size = to_2tuple(image_size)
|
image_size, patch_size = config.image_size, config.patch_size
|
||||||
patch_size = to_2tuple(patch_size)
|
num_channels, hidden_size = config.num_channels, config.hidden_size
|
||||||
|
|
||||||
|
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
|
||||||
|
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
|
||||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||||
self.image_size = image_size
|
self.image_size = image_size
|
||||||
self.patch_size = patch_size
|
self.patch_size = patch_size
|
||||||
|
self.num_channels = num_channels
|
||||||
self.num_patches = num_patches
|
self.num_patches = num_patches
|
||||||
|
|
||||||
self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
|
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
|
||||||
|
|
||||||
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
||||||
batch_size, num_channels, height, width = pixel_values.shape
|
batch_size, num_channels, height, width = pixel_values.shape
|
||||||
# FIXME look at relaxing size constraints
|
if num_channels != self.num_channels:
|
||||||
|
raise ValueError(
|
||||||
|
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
||||||
|
)
|
||||||
if height != self.image_size[0] or width != self.image_size[1]:
|
if height != self.image_size[0] or width != self.image_size[1]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
|
f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
|
||||||
@@ -483,7 +468,7 @@ class DeiTModel(DeiTPreTrainedModel):
|
|||||||
# Initialize weights and apply final processing
|
# Initialize weights and apply final processing
|
||||||
self.post_init()
|
self.post_init()
|
||||||
|
|
||||||
def get_input_embeddings(self) -> PatchEmbeddings:
|
def get_input_embeddings(self) -> DeiTPatchEmbeddings:
|
||||||
return self.embeddings.patch_embeddings
|
return self.embeddings.patch_embeddings
|
||||||
|
|
||||||
def _prune_heads(self, heads_to_prune):
|
def _prune_heads(self, heads_to_prune):
|
||||||
@@ -570,8 +555,8 @@ class DeiTPooler(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
"DeiT Model with a decoder on top for masked image modeling, as proposed in `SimMIM"
|
"DeiT Model with a decoder on top for masked image modeling, as proposed in"
|
||||||
" <https://arxiv.org/abs/2111.09886>`__.",
|
" [SimMIM](https://arxiv.org/abs/2111.09886).",
|
||||||
DEIT_START_DOCSTRING,
|
DEIT_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
class DeiTForMaskedImageModeling(DeiTPreTrainedModel):
|
class DeiTForMaskedImageModeling(DeiTPreTrainedModel):
|
||||||
@@ -581,7 +566,11 @@ class DeiTForMaskedImageModeling(DeiTPreTrainedModel):
|
|||||||
self.deit = DeiTModel(config, add_pooling_layer=False, use_mask_token=True)
|
self.deit = DeiTModel(config, add_pooling_layer=False, use_mask_token=True)
|
||||||
|
|
||||||
self.decoder = nn.Sequential(
|
self.decoder = nn.Sequential(
|
||||||
nn.Conv2d(in_channels=config.hidden_size, out_channels=config.encoder_stride**2 * 3, kernel_size=1),
|
nn.Conv2d(
|
||||||
|
in_channels=config.hidden_size,
|
||||||
|
out_channels=config.encoder_stride**2 * config.num_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
),
|
||||||
nn.PixelShuffle(config.encoder_stride),
|
nn.PixelShuffle(config.encoder_stride),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -65,13 +65,6 @@ DPT_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.vit.modeling_vit.to_2tuple
|
|
||||||
def to_2tuple(x):
|
|
||||||
if isinstance(x, collections.abc.Iterable):
|
|
||||||
return x
|
|
||||||
return (x, x)
|
|
||||||
|
|
||||||
|
|
||||||
class DPTViTEmbeddings(nn.Module):
|
class DPTViTEmbeddings(nn.Module):
|
||||||
"""
|
"""
|
||||||
Construct the CLS token, position and patch embeddings.
|
Construct the CLS token, position and patch embeddings.
|
||||||
@@ -82,12 +75,7 @@ class DPTViTEmbeddings(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
|
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
|
||||||
self.patch_embeddings = DPTViTPatchEmbeddings(
|
self.patch_embeddings = DPTViTPatchEmbeddings(config)
|
||||||
image_size=config.image_size,
|
|
||||||
patch_size=config.patch_size,
|
|
||||||
num_channels=config.num_channels,
|
|
||||||
embed_dim=config.hidden_size,
|
|
||||||
)
|
|
||||||
num_patches = self.patch_embeddings.num_patches
|
num_patches = self.patch_embeddings.num_patches
|
||||||
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
|
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
@@ -138,19 +126,27 @@ class DPTViTPatchEmbeddings(nn.Module):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, image_size=224, patch_size=16, num_channels=3, embed_dim=768):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
image_size = to_2tuple(image_size)
|
image_size, patch_size = config.image_size, config.patch_size
|
||||||
patch_size = to_2tuple(patch_size)
|
num_channels, hidden_size = config.num_channels, config.hidden_size
|
||||||
|
|
||||||
|
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
|
||||||
|
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
|
||||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||||
self.image_size = image_size
|
self.image_size = image_size
|
||||||
self.patch_size = patch_size
|
self.patch_size = patch_size
|
||||||
|
self.num_channels = num_channels
|
||||||
self.num_patches = num_patches
|
self.num_patches = num_patches
|
||||||
|
|
||||||
self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
|
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
|
||||||
|
|
||||||
def forward(self, pixel_values):
|
def forward(self, pixel_values):
|
||||||
batch_size, num_channels, height, width = pixel_values.shape
|
batch_size, num_channels, height, width = pixel_values.shape
|
||||||
|
if num_channels != self.num_channels:
|
||||||
|
raise ValueError(
|
||||||
|
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
||||||
|
)
|
||||||
embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
|
embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
|
|||||||
@@ -54,21 +54,23 @@ GLPN_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
|||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.segformer.modeling_segformer.drop_path
|
# Copied from transformers.models.segformer.modeling_segformer.drop_path
|
||||||
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
|
def drop_path(input, drop_prob: float = 0.0, training: bool = False):
|
||||||
"""
|
"""
|
||||||
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). This is the same as the
|
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||||
DropConnect impl I created for EfficientNet, etc networks, however, the original name is misleading as 'Drop
|
|
||||||
Connect' is a different form of dropout in a separate paper... See discussion:
|
Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
|
||||||
https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the layer and
|
however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
||||||
argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the argument.
|
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
|
||||||
|
layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
|
||||||
|
argument.
|
||||||
"""
|
"""
|
||||||
if drop_prob == 0.0 or not training:
|
if drop_prob == 0.0 or not training:
|
||||||
return x
|
return input
|
||||||
keep_prob = 1 - drop_prob
|
keep_prob = 1 - drop_prob
|
||||||
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
||||||
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
|
||||||
random_tensor.floor_() # binarize
|
random_tensor.floor_() # binarize
|
||||||
output = x.div(keep_prob) * random_tensor
|
output = input.div(keep_prob) * random_tensor
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
@@ -76,13 +78,16 @@ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
|
|||||||
class GLPNDropPath(nn.Module):
|
class GLPNDropPath(nn.Module):
|
||||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
||||||
|
|
||||||
def __init__(self, drop_prob=None):
|
def __init__(self, drop_prob: Optional[float] = None) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.drop_prob = drop_prob
|
self.drop_prob = drop_prob
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
return drop_path(x, self.drop_prob, self.training)
|
return drop_path(x, self.drop_prob, self.training)
|
||||||
|
|
||||||
|
def extra_repr(self) -> str:
|
||||||
|
return "p={}".format(self.drop_prob)
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.segformer.modeling_segformer.SegformerOverlapPatchEmbeddings
|
# Copied from transformers.models.segformer.modeling_segformer.SegformerOverlapPatchEmbeddings
|
||||||
class GLPNOverlapPatchEmbeddings(nn.Module):
|
class GLPNOverlapPatchEmbeddings(nn.Module):
|
||||||
|
|||||||
@@ -126,8 +126,14 @@ class LevitPatchEmbeddings(nn.Module):
|
|||||||
self.embedding_layer_4 = LevitConvEmbeddings(
|
self.embedding_layer_4 = LevitConvEmbeddings(
|
||||||
config.hidden_sizes[0] // 2, config.hidden_sizes[0], config.kernel_size, config.stride, config.padding
|
config.hidden_sizes[0] // 2, config.hidden_sizes[0], config.kernel_size, config.stride, config.padding
|
||||||
)
|
)
|
||||||
|
self.num_channels = config.num_channels
|
||||||
|
|
||||||
def forward(self, pixel_values):
|
def forward(self, pixel_values):
|
||||||
|
num_channels = pixel_values.shape[1]
|
||||||
|
if num_channels != self.num_channels:
|
||||||
|
raise ValueError(
|
||||||
|
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
||||||
|
)
|
||||||
embeddings = self.embedding_layer_1(pixel_values)
|
embeddings = self.embedding_layer_1(pixel_values)
|
||||||
embeddings = self.activation_layer_1(embeddings)
|
embeddings = self.activation_layer_1(embeddings)
|
||||||
embeddings = self.embedding_layer_2(embeddings)
|
embeddings = self.embedding_layer_2(embeddings)
|
||||||
|
|||||||
@@ -471,13 +471,6 @@ def pair_wise_sigmoid_focal_loss(inputs: Tensor, labels: Tensor, alpha: float =
|
|||||||
return loss / height_and_width
|
return loss / height_and_width
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.vit.modeling_vit.to_2tuple
|
|
||||||
def to_2tuple(x):
|
|
||||||
if isinstance(x, collections.abc.Iterable):
|
|
||||||
return x
|
|
||||||
return (x, x)
|
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.swin.modeling_swin.window_partition
|
# Copied from transformers.models.swin.modeling_swin.window_partition
|
||||||
def window_partition(input_feature, window_size):
|
def window_partition(input_feature, window_size):
|
||||||
"""
|
"""
|
||||||
@@ -506,15 +499,21 @@ def window_reverse(windows, window_size, height, width):
|
|||||||
def drop_path(input, drop_prob=0.0, training=False, scale_by_keep=True):
|
def drop_path(input, drop_prob=0.0, training=False, scale_by_keep=True):
|
||||||
"""
|
"""
|
||||||
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||||
|
|
||||||
|
Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
|
||||||
|
however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
||||||
|
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
|
||||||
|
layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
|
||||||
|
argument.
|
||||||
"""
|
"""
|
||||||
if drop_prob == 0.0 or not training:
|
if drop_prob == 0.0 or not training:
|
||||||
return input
|
return input
|
||||||
keep_prob = 1 - drop_prob
|
keep_prob = 1 - drop_prob
|
||||||
shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
||||||
random_tensor = input.new_empty(shape).bernoulli_(keep_prob)
|
random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
|
||||||
if keep_prob > 0.0 and scale_by_keep:
|
random_tensor.floor_() # binarize
|
||||||
random_tensor.div_(keep_prob)
|
output = input.div(keep_prob) * random_tensor
|
||||||
return input * random_tensor
|
return output
|
||||||
|
|
||||||
|
|
||||||
class MaskFormerSwinEmbeddings(nn.Module):
|
class MaskFormerSwinEmbeddings(nn.Module):
|
||||||
@@ -525,12 +524,7 @@ class MaskFormerSwinEmbeddings(nn.Module):
|
|||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.patch_embeddings = MaskFormerSwinPatchEmbeddings(
|
self.patch_embeddings = MaskFormerSwinPatchEmbeddings(config)
|
||||||
image_size=config.image_size,
|
|
||||||
patch_size=config.patch_size,
|
|
||||||
num_channels=config.num_channels,
|
|
||||||
embed_dim=config.embed_dim,
|
|
||||||
)
|
|
||||||
num_patches = self.patch_embeddings.num_patches
|
num_patches = self.patch_embeddings.num_patches
|
||||||
self.patch_grid = self.patch_embeddings.grid_size
|
self.patch_grid = self.patch_embeddings.grid_size
|
||||||
|
|
||||||
@@ -559,17 +553,21 @@ class MaskFormerSwinPatchEmbeddings(nn.Module):
|
|||||||
Image to Patch Embedding, including padding.
|
Image to Patch Embedding, including padding.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, image_size=224, patch_size=16, num_channels=3, embed_dim=768):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
image_size = to_2tuple(image_size)
|
image_size, patch_size = config.image_size, config.patch_size
|
||||||
patch_size = to_2tuple(patch_size)
|
num_channels, hidden_size = config.num_channels, config.embed_dim
|
||||||
|
|
||||||
|
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
|
||||||
|
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
|
||||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||||
self.image_size = image_size
|
self.image_size = image_size
|
||||||
self.patch_size = patch_size
|
self.patch_size = patch_size
|
||||||
|
self.num_channels = num_channels
|
||||||
self.num_patches = num_patches
|
self.num_patches = num_patches
|
||||||
self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
|
self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
|
||||||
|
|
||||||
self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
|
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
|
||||||
|
|
||||||
def maybe_pad(self, pixel_values, height, width):
|
def maybe_pad(self, pixel_values, height, width):
|
||||||
if width % self.patch_size[1] != 0:
|
if width % self.patch_size[1] != 0:
|
||||||
@@ -581,7 +579,11 @@ class MaskFormerSwinPatchEmbeddings(nn.Module):
|
|||||||
return pixel_values
|
return pixel_values
|
||||||
|
|
||||||
def forward(self, pixel_values):
|
def forward(self, pixel_values):
|
||||||
_, _, height, width = pixel_values.shape
|
_, num_channels, height, width = pixel_values.shape
|
||||||
|
if num_channels != self.num_channels:
|
||||||
|
raise ValueError(
|
||||||
|
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
||||||
|
)
|
||||||
# pad the input to be divisible by self.patch_size, if needed
|
# pad the input to be divisible by self.patch_size, if needed
|
||||||
pixel_values = self.maybe_pad(pixel_values, height, width)
|
pixel_values = self.maybe_pad(pixel_values, height, width)
|
||||||
embeddings = self.projection(pixel_values)
|
embeddings = self.projection(pixel_values)
|
||||||
@@ -649,13 +651,15 @@ class MaskFormerSwinPatchMerging(nn.Module):
|
|||||||
class MaskFormerSwinDropPath(nn.Module):
|
class MaskFormerSwinDropPath(nn.Module):
|
||||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
||||||
|
|
||||||
def __init__(self, drop_prob=None, scale_by_keep=True):
|
def __init__(self, drop_prob: Optional[float] = None) -> None:
|
||||||
super(MaskFormerSwinDropPath, self).__init__()
|
super().__init__()
|
||||||
self.drop_prob = drop_prob
|
self.drop_prob = drop_prob
|
||||||
self.scale_by_keep = scale_by_keep
|
|
||||||
|
|
||||||
def forward(self, input):
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
return drop_path(input, self.drop_prob, self.training, self.scale_by_keep)
|
return drop_path(x, self.drop_prob, self.training)
|
||||||
|
|
||||||
|
def extra_repr(self) -> str:
|
||||||
|
return "p={}".format(self.drop_prob)
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.swin.modeling_swin.SwinSelfAttention with Swin->MaskFormerSwin
|
# Copied from transformers.models.swin.modeling_swin.SwinSelfAttention with Swin->MaskFormerSwin
|
||||||
@@ -670,7 +674,10 @@ class MaskFormerSwinSelfAttention(nn.Module):
|
|||||||
self.num_attention_heads = num_heads
|
self.num_attention_heads = num_heads
|
||||||
self.attention_head_size = int(dim / num_heads)
|
self.attention_head_size = int(dim / num_heads)
|
||||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||||
self.window_size = to_2tuple(config.window_size)
|
window_size = config.window_size
|
||||||
|
self.window_size = (
|
||||||
|
window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size)
|
||||||
|
)
|
||||||
|
|
||||||
self.relative_position_bias_table = nn.Parameter(
|
self.relative_position_bias_table = nn.Parameter(
|
||||||
torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads)
|
torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads)
|
||||||
|
|||||||
@@ -50,40 +50,41 @@ POOLFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.vit.modeling_vit.to_2tuple
|
# Copied from transformers.models.beit.modeling_beit.drop_path
|
||||||
def to_2tuple(x):
|
def drop_path(input, drop_prob: float = 0.0, training: bool = False):
|
||||||
if isinstance(x, collections.abc.Iterable):
|
"""
|
||||||
return x
|
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||||
return (x, x)
|
|
||||||
|
|
||||||
|
Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
|
||||||
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
|
however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
||||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
|
||||||
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, the original name is
|
layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
|
||||||
misleading as 'Drop Connect' is a different form of dropout in a separate paper... See discussion:
|
argument.
|
||||||
https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the layer and
|
|
||||||
argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the argument.
|
|
||||||
"""
|
"""
|
||||||
if drop_prob == 0.0 or not training:
|
if drop_prob == 0.0 or not training:
|
||||||
return x
|
return input
|
||||||
keep_prob = 1 - drop_prob
|
keep_prob = 1 - drop_prob
|
||||||
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
||||||
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
|
||||||
random_tensor.floor_() # binarize
|
random_tensor.floor_() # binarize
|
||||||
output = x.div(keep_prob) * random_tensor
|
output = input.div(keep_prob) * random_tensor
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->PoolFormer
|
||||||
class PoolFormerDropPath(nn.Module):
|
class PoolFormerDropPath(nn.Module):
|
||||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
||||||
|
|
||||||
def __init__(self, drop_prob=None):
|
def __init__(self, drop_prob: Optional[float] = None) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.drop_prob = drop_prob
|
self.drop_prob = drop_prob
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
return drop_path(x, self.drop_prob, self.training)
|
return drop_path(x, self.drop_prob, self.training)
|
||||||
|
|
||||||
|
def extra_repr(self) -> str:
|
||||||
|
return "p={}".format(self.drop_prob)
|
||||||
|
|
||||||
|
|
||||||
class PoolFormerEmbeddings(nn.Module):
|
class PoolFormerEmbeddings(nn.Module):
|
||||||
"""
|
"""
|
||||||
@@ -92,17 +93,17 @@ class PoolFormerEmbeddings(nn.Module):
|
|||||||
|
|
||||||
def __init__(self, hidden_size, num_channels, patch_size, stride, padding, norm_layer=None):
|
def __init__(self, hidden_size, num_channels, patch_size, stride, padding, norm_layer=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
patch_size = to_2tuple(patch_size)
|
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
|
||||||
stride = to_2tuple(stride)
|
stride = stride if isinstance(stride, collections.abc.Iterable) else (stride, stride)
|
||||||
padding = to_2tuple(padding)
|
padding = padding if isinstance(padding, collections.abc.Iterable) else (padding, padding)
|
||||||
|
|
||||||
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=stride, padding=padding)
|
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=stride, padding=padding)
|
||||||
self.norm = norm_layer(hidden_size) if norm_layer else nn.Identity()
|
self.norm = norm_layer(hidden_size) if norm_layer else nn.Identity()
|
||||||
|
|
||||||
def forward(self, pixel_values):
|
def forward(self, pixel_values):
|
||||||
x = self.projection(pixel_values)
|
embeddings = self.projection(pixel_values)
|
||||||
x = self.norm(x)
|
embeddings = self.norm(embeddings)
|
||||||
return x
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
class PoolFormerGroupNorm(nn.GroupNorm):
|
class PoolFormerGroupNorm(nn.GroupNorm):
|
||||||
|
|||||||
@@ -93,9 +93,15 @@ class RegNetEmbeddings(nn.Module):
|
|||||||
self.embedder = RegNetConvLayer(
|
self.embedder = RegNetConvLayer(
|
||||||
config.num_channels, config.embedding_size, kernel_size=3, stride=2, activation=config.hidden_act
|
config.num_channels, config.embedding_size, kernel_size=3, stride=2, activation=config.hidden_act
|
||||||
)
|
)
|
||||||
|
self.num_channels = config.num_channels
|
||||||
|
|
||||||
def forward(self, hidden_state):
|
def forward(self, pixel_values):
|
||||||
hidden_state = self.embedder(hidden_state)
|
num_channels = pixel_values.shape[1]
|
||||||
|
if num_channels != self.num_channels:
|
||||||
|
raise ValueError(
|
||||||
|
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
||||||
|
)
|
||||||
|
hidden_state = self.embedder(pixel_values)
|
||||||
return hidden_state
|
return hidden_state
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -81,9 +81,15 @@ class ResNetEmbeddings(nn.Module):
|
|||||||
config.num_channels, config.embedding_size, kernel_size=7, stride=2, activation=config.hidden_act
|
config.num_channels, config.embedding_size, kernel_size=7, stride=2, activation=config.hidden_act
|
||||||
)
|
)
|
||||||
self.pooler = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
self.pooler = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||||
|
self.num_channels = config.num_channels
|
||||||
|
|
||||||
def forward(self, input: Tensor) -> Tensor:
|
def forward(self, pixel_values: Tensor) -> Tensor:
|
||||||
embedding = self.embedder(input)
|
num_channels = pixel_values.shape[1]
|
||||||
|
if num_channels != self.num_channels:
|
||||||
|
raise ValueError(
|
||||||
|
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
||||||
|
)
|
||||||
|
embedding = self.embedder(pixel_values)
|
||||||
embedding = self.pooler(embedding)
|
embedding = self.pooler(embedding)
|
||||||
return embedding
|
return embedding
|
||||||
|
|
||||||
@@ -107,7 +113,7 @@ class ResNetShortCut(nn.Module):
|
|||||||
|
|
||||||
class ResNetBasicLayer(nn.Module):
|
class ResNetBasicLayer(nn.Module):
|
||||||
"""
|
"""
|
||||||
A classic ResNet's residual layer composed by a two `3x3` convolutions.
|
A classic ResNet's residual layer composed by two `3x3` convolutions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, in_channels: int, out_channels: int, stride: int = 1, activation: str = "relu"):
|
def __init__(self, in_channels: int, out_channels: int, stride: int = 1, activation: str = "relu"):
|
||||||
@@ -133,10 +139,10 @@ class ResNetBasicLayer(nn.Module):
|
|||||||
|
|
||||||
class ResNetBottleNeckLayer(nn.Module):
|
class ResNetBottleNeckLayer(nn.Module):
|
||||||
"""
|
"""
|
||||||
A classic ResNet's bottleneck layer composed by a three `3x3` convolutions.
|
A classic ResNet's bottleneck layer composed by three `3x3` convolutions.
|
||||||
|
|
||||||
The first `1x1` convolution reduces the input by a factor of `reduction` in order to make the second `3x3`
|
The first `1x1` convolution reduces the input by a factor of `reduction` in order to make the second `3x3`
|
||||||
convolution faster. The last `1x1` convolution remap the reduced features to `out_channels`.
|
convolution faster. The last `1x1` convolution remaps the reduced features to `out_channels`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
@@ -86,21 +86,23 @@ class SegFormerImageClassifierOutput(ImageClassifierOutput):
|
|||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.convnext.modeling_convnext.drop_path
|
# Copied from transformers.models.convnext.modeling_convnext.drop_path
|
||||||
def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep=True):
|
def drop_path(input, drop_prob: float = 0.0, training: bool = False, scale_by_keep=True):
|
||||||
"""
|
"""
|
||||||
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). This is the same as the
|
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||||
DropConnect impl I created for EfficientNet, etc networks, however, the original name is misleading as 'Drop
|
|
||||||
Connect' is a different form of dropout in a separate paper... See discussion:
|
Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
|
||||||
https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the layer and
|
however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
||||||
argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the argument.
|
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
|
||||||
|
layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
|
||||||
|
argument.
|
||||||
"""
|
"""
|
||||||
if drop_prob == 0.0 or not training:
|
if drop_prob == 0.0 or not training:
|
||||||
return x
|
return input
|
||||||
keep_prob = 1 - drop_prob
|
keep_prob = 1 - drop_prob
|
||||||
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
||||||
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
|
||||||
random_tensor.floor_() # binarize
|
random_tensor.floor_() # binarize
|
||||||
output = x.div(keep_prob) * random_tensor
|
output = input.div(keep_prob) * random_tensor
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
@@ -108,13 +110,16 @@ def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep=T
|
|||||||
class SegformerDropPath(nn.Module):
|
class SegformerDropPath(nn.Module):
|
||||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
||||||
|
|
||||||
def __init__(self, drop_prob=None):
|
def __init__(self, drop_prob: Optional[float] = None) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.drop_prob = drop_prob
|
self.drop_prob = drop_prob
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
return drop_path(x, self.drop_prob, self.training)
|
return drop_path(x, self.drop_prob, self.training)
|
||||||
|
|
||||||
|
def extra_repr(self) -> str:
|
||||||
|
return "p={}".format(self.drop_prob)
|
||||||
|
|
||||||
|
|
||||||
class SegformerOverlapPatchEmbeddings(nn.Module):
|
class SegformerOverlapPatchEmbeddings(nn.Module):
|
||||||
"""Construct the overlapping patch embeddings."""
|
"""Construct the overlapping patch embeddings."""
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ SWIN_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
|||||||
# See all Swin models at https://huggingface.co/models?filter=swin
|
# See all Swin models at https://huggingface.co/models?filter=swin
|
||||||
]
|
]
|
||||||
|
|
||||||
# to_2tuple, drop_path, SwinPatchEmbeddings, SwinPatchMerging and SwinDropPath are from the timm library.
|
# drop_path, SwinPatchEmbeddings, SwinPatchMerging and SwinDropPath are from the timm library.
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -203,13 +203,6 @@ class SwinImageClassifierOutput(ModelOutput):
|
|||||||
reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.vit.modeling_vit.to_2tuple
|
|
||||||
def to_2tuple(x):
|
|
||||||
if isinstance(x, collections.abc.Iterable):
|
|
||||||
return x
|
|
||||||
return (x, x)
|
|
||||||
|
|
||||||
|
|
||||||
def window_partition(input_feature, window_size):
|
def window_partition(input_feature, window_size):
|
||||||
"""
|
"""
|
||||||
Partitions the given input into windows.
|
Partitions the given input into windows.
|
||||||
@@ -232,20 +225,6 @@ def window_reverse(windows, window_size, height, width):
|
|||||||
return windows
|
return windows
|
||||||
|
|
||||||
|
|
||||||
def drop_path(input, drop_prob=0.0, training=False, scale_by_keep=True):
|
|
||||||
"""
|
|
||||||
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
|
||||||
"""
|
|
||||||
if drop_prob == 0.0 or not training:
|
|
||||||
return input
|
|
||||||
keep_prob = 1 - drop_prob
|
|
||||||
shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
|
||||||
random_tensor = input.new_empty(shape).bernoulli_(keep_prob)
|
|
||||||
if keep_prob > 0.0 and scale_by_keep:
|
|
||||||
random_tensor.div_(keep_prob)
|
|
||||||
return input * random_tensor
|
|
||||||
|
|
||||||
|
|
||||||
class SwinEmbeddings(nn.Module):
|
class SwinEmbeddings(nn.Module):
|
||||||
"""
|
"""
|
||||||
Construct the patch and position embeddings. Optionally, also the mask token.
|
Construct the patch and position embeddings. Optionally, also the mask token.
|
||||||
@@ -254,12 +233,7 @@ class SwinEmbeddings(nn.Module):
|
|||||||
def __init__(self, config, use_mask_token=False):
|
def __init__(self, config, use_mask_token=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.patch_embeddings = SwinPatchEmbeddings(
|
self.patch_embeddings = SwinPatchEmbeddings(config)
|
||||||
image_size=config.image_size,
|
|
||||||
patch_size=config.patch_size,
|
|
||||||
num_channels=config.num_channels,
|
|
||||||
embed_dim=config.embed_dim,
|
|
||||||
)
|
|
||||||
num_patches = self.patch_embeddings.num_patches
|
num_patches = self.patch_embeddings.num_patches
|
||||||
self.patch_grid = self.patch_embeddings.grid_size
|
self.patch_grid = self.patch_embeddings.grid_size
|
||||||
self.mask_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim)) if use_mask_token else None
|
self.mask_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim)) if use_mask_token else None
|
||||||
@@ -295,20 +269,25 @@ class SwinEmbeddings(nn.Module):
|
|||||||
|
|
||||||
class SwinPatchEmbeddings(nn.Module):
|
class SwinPatchEmbeddings(nn.Module):
|
||||||
"""
|
"""
|
||||||
Image to Patch Embedding.
|
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
|
||||||
|
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
|
||||||
|
Transformer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, image_size=224, patch_size=16, num_channels=3, embed_dim=768):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
image_size = to_2tuple(image_size)
|
image_size, patch_size = config.image_size, config.patch_size
|
||||||
patch_size = to_2tuple(patch_size)
|
num_channels, hidden_size = config.num_channels, config.embed_dim
|
||||||
|
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
|
||||||
|
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
|
||||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||||
self.image_size = image_size
|
self.image_size = image_size
|
||||||
self.patch_size = patch_size
|
self.patch_size = patch_size
|
||||||
|
self.num_channels = num_channels
|
||||||
self.num_patches = num_patches
|
self.num_patches = num_patches
|
||||||
self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
|
self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
|
||||||
|
|
||||||
self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
|
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
|
||||||
|
|
||||||
def maybe_pad(self, pixel_values, height, width):
|
def maybe_pad(self, pixel_values, height, width):
|
||||||
if width % self.patch_size[1] != 0:
|
if width % self.patch_size[1] != 0:
|
||||||
@@ -320,7 +299,11 @@ class SwinPatchEmbeddings(nn.Module):
|
|||||||
return pixel_values
|
return pixel_values
|
||||||
|
|
||||||
def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
|
def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
|
||||||
_, _, height, width = pixel_values.shape
|
_, num_channels, height, width = pixel_values.shape
|
||||||
|
if num_channels != self.num_channels:
|
||||||
|
raise ValueError(
|
||||||
|
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
||||||
|
)
|
||||||
# pad the input to be divisible by self.patch_size, if needed
|
# pad the input to be divisible by self.patch_size, if needed
|
||||||
pixel_values = self.maybe_pad(pixel_values, height, width)
|
pixel_values = self.maybe_pad(pixel_values, height, width)
|
||||||
embeddings = self.projection(pixel_values)
|
embeddings = self.projection(pixel_values)
|
||||||
@@ -385,16 +368,40 @@ class SwinPatchMerging(nn.Module):
|
|||||||
return input_feature
|
return input_feature
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.beit.modeling_beit.drop_path
|
||||||
|
def drop_path(input, drop_prob=0.0, training=False, scale_by_keep=True):
|
||||||
|
"""
|
||||||
|
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||||
|
|
||||||
|
Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
|
||||||
|
however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
||||||
|
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
|
||||||
|
layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
|
||||||
|
argument.
|
||||||
|
"""
|
||||||
|
if drop_prob == 0.0 or not training:
|
||||||
|
return input
|
||||||
|
keep_prob = 1 - drop_prob
|
||||||
|
shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
||||||
|
random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
|
||||||
|
random_tensor.floor_() # binarize
|
||||||
|
output = input.div(keep_prob) * random_tensor
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->Swin
|
||||||
class SwinDropPath(nn.Module):
|
class SwinDropPath(nn.Module):
|
||||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
||||||
|
|
||||||
def __init__(self, drop_prob=None, scale_by_keep=True):
|
def __init__(self, drop_prob: Optional[float] = None) -> None:
|
||||||
super(SwinDropPath, self).__init__()
|
super().__init__()
|
||||||
self.drop_prob = drop_prob
|
self.drop_prob = drop_prob
|
||||||
self.scale_by_keep = scale_by_keep
|
|
||||||
|
|
||||||
def forward(self, input):
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
return drop_path(input, self.drop_prob, self.training, self.scale_by_keep)
|
return drop_path(x, self.drop_prob, self.training)
|
||||||
|
|
||||||
|
def extra_repr(self) -> str:
|
||||||
|
return "p={}".format(self.drop_prob)
|
||||||
|
|
||||||
|
|
||||||
class SwinSelfAttention(nn.Module):
|
class SwinSelfAttention(nn.Module):
|
||||||
@@ -408,7 +415,10 @@ class SwinSelfAttention(nn.Module):
|
|||||||
self.num_attention_heads = num_heads
|
self.num_attention_heads = num_heads
|
||||||
self.attention_head_size = int(dim / num_heads)
|
self.attention_head_size = int(dim / num_heads)
|
||||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||||
self.window_size = to_2tuple(config.window_size)
|
window_size = config.window_size
|
||||||
|
self.window_size = (
|
||||||
|
window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size)
|
||||||
|
)
|
||||||
|
|
||||||
self.relative_position_bias_table = nn.Parameter(
|
self.relative_position_bias_table = nn.Parameter(
|
||||||
torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads)
|
torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads)
|
||||||
@@ -997,8 +1007,8 @@ class SwinModel(SwinPreTrainedModel):
|
|||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
"Swin Model with a decoder on top for masked image modeling, as proposed in `SimMIM"
|
"Swin Model with a decoder on top for masked image modeling, as proposed in"
|
||||||
" <https://arxiv.org/abs/2111.09886>`__.",
|
" [SimMIM](https://arxiv.org/abs/2111.09886).",
|
||||||
SWIN_START_DOCSTRING,
|
SWIN_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
class SwinForMaskedImageModeling(SwinPreTrainedModel):
|
class SwinForMaskedImageModeling(SwinPreTrainedModel):
|
||||||
@@ -1009,7 +1019,9 @@ class SwinForMaskedImageModeling(SwinPreTrainedModel):
|
|||||||
|
|
||||||
num_features = int(config.embed_dim * 2 ** (config.num_layers - 1))
|
num_features = int(config.embed_dim * 2 ** (config.num_layers - 1))
|
||||||
self.decoder = nn.Sequential(
|
self.decoder = nn.Sequential(
|
||||||
nn.Conv2d(in_channels=num_features, out_channels=config.encoder_stride**2 * 3, kernel_size=1),
|
nn.Conv2d(
|
||||||
|
in_channels=num_features, out_channels=config.encoder_stride**2 * config.num_channels, kernel_size=1
|
||||||
|
),
|
||||||
nn.PixelShuffle(config.encoder_stride),
|
nn.PixelShuffle(config.encoder_stride),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -63,7 +63,7 @@ TF_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
|||||||
# See all Swin models at https://huggingface.co/models?filter=swin
|
# See all Swin models at https://huggingface.co/models?filter=swin
|
||||||
]
|
]
|
||||||
|
|
||||||
# to_2tuple, drop_path, TFSwinPatchEmbeddings, TFSwinPatchMerging and TFSwinDropPath are tensorflow
|
# drop_path, TFSwinPatchEmbeddings, TFSwinPatchMerging and TFSwinDropPath are tensorflow
|
||||||
# implementations of PyTorch functionalities in the timm library.
|
# implementations of PyTorch functionalities in the timm library.
|
||||||
|
|
||||||
|
|
||||||
@@ -208,13 +208,6 @@ class TFSwinImageClassifierOutput(ModelOutput):
|
|||||||
reshaped_hidden_states: Optional[Tuple[tf.Tensor]] = None
|
reshaped_hidden_states: Optional[Tuple[tf.Tensor]] = None
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.vit.modeling_tf_vit.to_2tuple
|
|
||||||
def to_2tuple(x) -> Tuple[Any, Any]:
|
|
||||||
if isinstance(x, collections.abc.Iterable):
|
|
||||||
return x
|
|
||||||
return (x, x)
|
|
||||||
|
|
||||||
|
|
||||||
def window_partition(input_feature: tf.Tensor, window_size: int) -> tf.Tensor:
|
def window_partition(input_feature: tf.Tensor, window_size: int) -> tf.Tensor:
|
||||||
"""
|
"""
|
||||||
Partitions the given input into windows.
|
Partitions the given input into windows.
|
||||||
@@ -270,13 +263,7 @@ class TFSwinEmbeddings(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
def __init__(self, config: SwinConfig, use_mask_token: bool = False, **kwargs) -> None:
|
def __init__(self, config: SwinConfig, use_mask_token: bool = False, **kwargs) -> None:
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.patch_embeddings = TFSwinPatchEmbeddings(
|
self.patch_embeddings = TFSwinPatchEmbeddings(config, name="patch_embeddings")
|
||||||
image_size=config.image_size,
|
|
||||||
patch_size=config.patch_size,
|
|
||||||
num_channels=config.num_channels,
|
|
||||||
embed_dim=config.embed_dim,
|
|
||||||
name="patch_embeddings",
|
|
||||||
)
|
|
||||||
self.num_patches = self.patch_embeddings.num_patches
|
self.num_patches = self.patch_embeddings.num_patches
|
||||||
self.patch_grid = self.patch_embeddings.grid_size
|
self.patch_grid = self.patch_embeddings.grid_size
|
||||||
self.embed_dim = config.embed_dim
|
self.embed_dim = config.embed_dim
|
||||||
@@ -329,20 +316,25 @@ class TFSwinPatchEmbeddings(tf.keras.layers.Layer):
|
|||||||
Image to Patch Embedding.
|
Image to Patch Embedding.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, config, **kwargs):
|
||||||
self, image_size: int = 224, patch_size: int = 16, num_channels: int = 3, embed_dim: int = 768, **kwargs
|
|
||||||
) -> None:
|
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
image_size = to_2tuple(image_size)
|
image_size, patch_size = config.image_size, config.patch_size
|
||||||
patch_size = to_2tuple(patch_size)
|
num_channels, hidden_size = config.num_channels, config.embed_dim
|
||||||
|
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
|
||||||
|
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
|
||||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||||
self.image_size = image_size
|
self.image_size = image_size
|
||||||
self.patch_size = patch_size
|
self.patch_size = patch_size
|
||||||
|
self.num_channels = num_channels
|
||||||
self.num_patches = num_patches
|
self.num_patches = num_patches
|
||||||
self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
|
self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
|
||||||
|
|
||||||
self.projection = tf.keras.layers.Conv2D(
|
self.projection = tf.keras.layers.Conv2D(
|
||||||
filters=embed_dim, kernel_size=self.patch_size, strides=self.patch_size, padding="valid", name="projection"
|
filters=hidden_size,
|
||||||
|
kernel_size=self.patch_size,
|
||||||
|
strides=self.patch_size,
|
||||||
|
padding="valid",
|
||||||
|
name="projection",
|
||||||
)
|
)
|
||||||
|
|
||||||
def maybe_pad(self, pixel_values: tf.Tensor, height: int, width: int) -> tf.Tensor:
|
def maybe_pad(self, pixel_values: tf.Tensor, height: int, width: int) -> tf.Tensor:
|
||||||
@@ -355,7 +347,11 @@ class TFSwinPatchEmbeddings(tf.keras.layers.Layer):
|
|||||||
return pixel_values
|
return pixel_values
|
||||||
|
|
||||||
def call(self, pixel_values: tf.Tensor, training: bool = False) -> Tuple[tf.Tensor, Tuple[int, int]]:
|
def call(self, pixel_values: tf.Tensor, training: bool = False) -> Tuple[tf.Tensor, Tuple[int, int]]:
|
||||||
_, _, height, width = shape_list(pixel_values)
|
_, num_channels, height, width = shape_list(pixel_values)
|
||||||
|
if tf.executing_eagerly() and num_channels != self.num_channels:
|
||||||
|
raise ValueError(
|
||||||
|
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
||||||
|
)
|
||||||
# pad the input to be divisible by self.patch_size, if needed
|
# pad the input to be divisible by self.patch_size, if needed
|
||||||
pixel_values = self.maybe_pad(pixel_values, height, width)
|
pixel_values = self.maybe_pad(pixel_values, height, width)
|
||||||
|
|
||||||
@@ -460,7 +456,10 @@ class TFSwinSelfAttention(tf.keras.layers.Layer):
|
|||||||
self.num_attention_heads = num_heads
|
self.num_attention_heads = num_heads
|
||||||
self.attention_head_size = int(dim / num_heads)
|
self.attention_head_size = int(dim / num_heads)
|
||||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||||
self.window_size = to_2tuple(config.window_size)
|
window_size = config.window_size
|
||||||
|
self.window_size = (
|
||||||
|
window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size)
|
||||||
|
)
|
||||||
|
|
||||||
# get pair-wise relative position index for each token inside the window
|
# get pair-wise relative position index for each token inside the window
|
||||||
coords_h = tf.range(self.window_size[0])
|
coords_h = tf.range(self.window_size[0])
|
||||||
@@ -1252,7 +1251,7 @@ class TFSwinDecoder(tf.keras.layers.Layer):
|
|||||||
def __init__(self, config: SwinConfig, **kwargs):
|
def __init__(self, config: SwinConfig, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.conv2d = tf.keras.layers.Conv2D(
|
self.conv2d = tf.keras.layers.Conv2D(
|
||||||
filters=config.encoder_stride**2 * 3, kernel_size=1, strides=1, name="0"
|
filters=config.encoder_stride**2 * config.num_channels, kernel_size=1, strides=1, name="0"
|
||||||
)
|
)
|
||||||
self._block_size = config.encoder_stride
|
self._block_size = config.encoder_stride
|
||||||
self.pixel_shuffle = PixelShuffle(self._block_size, name="1")
|
self.pixel_shuffle = PixelShuffle(self._block_size, name="1")
|
||||||
@@ -1280,8 +1279,8 @@ class TFSwinDecoder(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
"Swin Model with a decoder on top for masked image modeling, as proposed in `SimMIM"
|
"Swin Model with a decoder on top for masked image modeling, as proposed in"
|
||||||
" <https://arxiv.org/abs/2111.09886>`__.",
|
" [SimMIM](https://arxiv.org/abs/2111.09886).",
|
||||||
SWIN_START_DOCSTRING,
|
SWIN_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
class TFSwinForMaskedImageModeling(TFSwinPreTrainedModel):
|
class TFSwinForMaskedImageModeling(TFSwinPreTrainedModel):
|
||||||
|
|||||||
@@ -54,23 +54,24 @@ VAN_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
# Stochastic depth implementation
|
# Copied from transformers.models.convnext.modeling_convnext.drop_path
|
||||||
# Taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
|
def drop_path(input, drop_prob: float = 0.0, training: bool = False):
|
||||||
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
|
|
||||||
"""
|
"""
|
||||||
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). This is the same as the
|
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||||
DropConnect impl I created for EfficientNet, etc networks, however, the original name is misleading as 'Drop
|
|
||||||
Connect' is a different form of dropout in a separate paper... See discussion:
|
Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
|
||||||
https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the layer and
|
however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
||||||
argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the argument.
|
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
|
||||||
|
layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
|
||||||
|
argument.
|
||||||
"""
|
"""
|
||||||
if drop_prob == 0.0 or not training:
|
if drop_prob == 0.0 or not training:
|
||||||
return x
|
return input
|
||||||
keep_prob = 1 - drop_prob
|
keep_prob = 1 - drop_prob
|
||||||
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
||||||
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
|
||||||
random_tensor.floor_() # binarize
|
random_tensor.floor_() # binarize
|
||||||
output = x.div(keep_prob) * random_tensor
|
output = input.div(keep_prob) * random_tensor
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
@@ -78,13 +79,16 @@ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
|
|||||||
class VanDropPath(nn.Module):
|
class VanDropPath(nn.Module):
|
||||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
||||||
|
|
||||||
def __init__(self, drop_prob=None):
|
def __init__(self, drop_prob: Optional[float] = None) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.drop_prob = drop_prob
|
self.drop_prob = drop_prob
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
return drop_path(x, self.drop_prob, self.training)
|
return drop_path(x, self.drop_prob, self.training)
|
||||||
|
|
||||||
|
def extra_repr(self) -> str:
|
||||||
|
return "p={}".format(self.drop_prob)
|
||||||
|
|
||||||
|
|
||||||
class VanOverlappingPatchEmbedder(nn.Module):
|
class VanOverlappingPatchEmbedder(nn.Module):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -82,13 +82,6 @@ class ViltForImagesAndTextClassificationOutput(ModelOutput):
|
|||||||
attentions: Optional[List[Tuple[torch.FloatTensor]]] = None
|
attentions: Optional[List[Tuple[torch.FloatTensor]]] = None
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.vit.modeling_vit.to_2tuple
|
|
||||||
def to_2tuple(x):
|
|
||||||
if isinstance(x, collections.abc.Iterable):
|
|
||||||
return x
|
|
||||||
return (x, x)
|
|
||||||
|
|
||||||
|
|
||||||
class ViltEmbeddings(nn.Module):
|
class ViltEmbeddings(nn.Module):
|
||||||
"""
|
"""
|
||||||
Construct the text and patch embeddings.
|
Construct the text and patch embeddings.
|
||||||
@@ -105,12 +98,7 @@ class ViltEmbeddings(nn.Module):
|
|||||||
self.text_embeddings = TextEmbeddings(config)
|
self.text_embeddings = TextEmbeddings(config)
|
||||||
# patch embeddings
|
# patch embeddings
|
||||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
|
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
|
||||||
self.patch_embeddings = PatchEmbeddings(
|
self.patch_embeddings = ViltPatchEmbeddings(config)
|
||||||
image_size=config.image_size,
|
|
||||||
patch_size=config.patch_size,
|
|
||||||
num_channels=config.num_channels,
|
|
||||||
embed_dim=config.hidden_size,
|
|
||||||
)
|
|
||||||
num_patches = self.patch_embeddings.num_patches
|
num_patches = self.patch_embeddings.num_patches
|
||||||
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
|
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
|
||||||
# modality type (text/patch) embeddings
|
# modality type (text/patch) embeddings
|
||||||
@@ -304,26 +292,32 @@ class TextEmbeddings(nn.Module):
|
|||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
# Based on timm implementation, which can be found here:
|
class ViltPatchEmbeddings(nn.Module):
|
||||||
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
|
||||||
class PatchEmbeddings(nn.Module):
|
|
||||||
"""
|
"""
|
||||||
Image to Patch Embedding.
|
Image to Patch Embedding.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, image_size=224, patch_size=16, num_channels=3, embed_dim=768):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
image_size = to_2tuple(image_size)
|
image_size, patch_size = config.image_size, config.patch_size
|
||||||
patch_size = to_2tuple(patch_size)
|
num_channels, hidden_size = config.num_channels, config.hidden_size
|
||||||
|
|
||||||
|
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
|
||||||
|
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
|
||||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||||
self.image_size = image_size
|
self.image_size = image_size
|
||||||
self.patch_size = patch_size
|
self.patch_size = patch_size
|
||||||
|
self.num_channels = num_channels
|
||||||
self.num_patches = num_patches
|
self.num_patches = num_patches
|
||||||
|
|
||||||
self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
|
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
|
||||||
|
|
||||||
def forward(self, pixel_values):
|
def forward(self, pixel_values):
|
||||||
batch_size, num_channels, height, width = pixel_values.shape
|
batch_size, num_channels, height, width = pixel_values.shape
|
||||||
|
if num_channels != self.num_channels:
|
||||||
|
raise ValueError(
|
||||||
|
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
||||||
|
)
|
||||||
x = self.projection(pixel_values)
|
x = self.projection(pixel_values)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|||||||
@@ -84,7 +84,7 @@ VIT_INPUTS_DOCSTRING = r"""
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class FlaxPatchEmbeddings(nn.Module):
|
class FlaxViTPatchEmbeddings(nn.Module):
|
||||||
|
|
||||||
config: ViTConfig
|
config: ViTConfig
|
||||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||||
@@ -94,6 +94,7 @@ class FlaxPatchEmbeddings(nn.Module):
|
|||||||
patch_size = self.config.patch_size
|
patch_size = self.config.patch_size
|
||||||
num_patches = (image_size // patch_size) * (image_size // patch_size)
|
num_patches = (image_size // patch_size) * (image_size // patch_size)
|
||||||
self.num_patches = num_patches
|
self.num_patches = num_patches
|
||||||
|
self.num_channels = self.config.num_channels
|
||||||
self.projection = nn.Conv(
|
self.projection = nn.Conv(
|
||||||
self.config.hidden_size,
|
self.config.hidden_size,
|
||||||
kernel_size=(patch_size, patch_size),
|
kernel_size=(patch_size, patch_size),
|
||||||
@@ -104,9 +105,14 @@ class FlaxPatchEmbeddings(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, pixel_values):
|
def __call__(self, pixel_values):
|
||||||
x = self.projection(pixel_values)
|
num_channels = pixel_values.shape[-1]
|
||||||
batch_size, _, _, channels = x.shape
|
if num_channels != self.num_channels:
|
||||||
return jnp.reshape(x, (batch_size, -1, channels))
|
raise ValueError(
|
||||||
|
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
||||||
|
)
|
||||||
|
embeddings = self.projection(pixel_values)
|
||||||
|
batch_size, _, _, channels = embeddings.shape
|
||||||
|
return jnp.reshape(embeddings, (batch_size, -1, channels))
|
||||||
|
|
||||||
|
|
||||||
class FlaxViTEmbeddings(nn.Module):
|
class FlaxViTEmbeddings(nn.Module):
|
||||||
@@ -117,7 +123,7 @@ class FlaxViTEmbeddings(nn.Module):
|
|||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
self.cls_token = self.param("cls_token", nn.initializers.zeros, (1, 1, self.config.hidden_size))
|
self.cls_token = self.param("cls_token", nn.initializers.zeros, (1, 1, self.config.hidden_size))
|
||||||
self.patch_embeddings = FlaxPatchEmbeddings(self.config, dtype=self.dtype)
|
self.patch_embeddings = FlaxViTPatchEmbeddings(self.config, dtype=self.dtype)
|
||||||
num_patches = self.patch_embeddings.num_patches
|
num_patches = self.patch_embeddings.num_patches
|
||||||
self.position_embeddings = self.param(
|
self.position_embeddings = self.param(
|
||||||
"position_embeddings", nn.initializers.zeros, (1, num_patches + 1, self.config.hidden_size)
|
"position_embeddings", nn.initializers.zeros, (1, num_patches + 1, self.config.hidden_size)
|
||||||
@@ -420,7 +426,7 @@ class FlaxViTPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
):
|
):
|
||||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||||
if input_shape is None:
|
if input_shape is None:
|
||||||
input_shape = (1, config.image_size, config.image_size, 3)
|
input_shape = (1, config.image_size, config.image_size, config.num_channels)
|
||||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||||
|
|
||||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||||
|
|||||||
@@ -52,19 +52,6 @@ _IMAGE_CLASS_CHECKPOINT = "google/vit-base-patch16-224"
|
|||||||
_IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat"
|
_IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat"
|
||||||
|
|
||||||
|
|
||||||
# Inspired by
|
|
||||||
# https://github.com/rwightman/pytorch-image-models/blob/b9bd960a032c75ca6b808ddeed76bee5f3ed4972/timm/models/layers/helpers.py
|
|
||||||
# From PyTorch internals
|
|
||||||
def to_2tuple(x):
|
|
||||||
if isinstance(x, collections.abc.Iterable):
|
|
||||||
return x
|
|
||||||
return (x, x)
|
|
||||||
|
|
||||||
|
|
||||||
# Based on timm implementation, which can be found here:
|
|
||||||
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
|
||||||
|
|
||||||
|
|
||||||
class TFViTEmbeddings(tf.keras.layers.Layer):
|
class TFViTEmbeddings(tf.keras.layers.Layer):
|
||||||
"""
|
"""
|
||||||
Construct the CLS token, position and patch embeddings.
|
Construct the CLS token, position and patch embeddings.
|
||||||
@@ -74,7 +61,7 @@ class TFViTEmbeddings(tf.keras.layers.Layer):
|
|||||||
def __init__(self, config: ViTConfig, **kwargs):
|
def __init__(self, config: ViTConfig, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
self.patch_embeddings = TFPatchEmbeddings(config, name="patch_embeddings")
|
self.patch_embeddings = TFViTPatchEmbeddings(config, name="patch_embeddings")
|
||||||
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
|
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
@@ -103,19 +90,21 @@ class TFViTEmbeddings(tf.keras.layers.Layer):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
batch_size, seq_len, dim = shape_list(embeddings)
|
batch_size, seq_len, dim = shape_list(embeddings)
|
||||||
npatch = seq_len - 1
|
num_patches = seq_len - 1
|
||||||
|
|
||||||
_, N, _ = shape_list(self.position_embeddings)
|
_, num_positions, _ = shape_list(self.position_embeddings)
|
||||||
N -= 1
|
num_positions -= 1
|
||||||
|
|
||||||
if npatch == N and height == width:
|
if num_patches == num_positions and height == width:
|
||||||
return self.position_embeddings
|
return self.position_embeddings
|
||||||
class_pos_embed = self.position_embeddings[:, :1]
|
class_pos_embed = self.position_embeddings[:, :1]
|
||||||
patch_pos_embed = self.position_embeddings[:, 1:]
|
patch_pos_embed = self.position_embeddings[:, 1:]
|
||||||
h0 = height // self.config.patch_size
|
h0 = height // self.config.patch_size
|
||||||
w0 = width // self.config.patch_size
|
w0 = width // self.config.patch_size
|
||||||
patch_pos_embed = tf.image.resize(
|
patch_pos_embed = tf.image.resize(
|
||||||
images=tf.reshape(patch_pos_embed, shape=(1, int(math.sqrt(N)), int(math.sqrt(N)), dim)),
|
images=tf.reshape(
|
||||||
|
patch_pos_embed, shape=(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
|
||||||
|
),
|
||||||
size=(h0, w0),
|
size=(h0, w0),
|
||||||
method="bicubic",
|
method="bicubic",
|
||||||
)
|
)
|
||||||
@@ -150,27 +139,31 @@ class TFViTEmbeddings(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
# Based on timm implementation, which can be found here:
|
# Based on timm implementation, which can be found here:
|
||||||
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
||||||
class TFPatchEmbeddings(tf.keras.layers.Layer):
|
class TFViTPatchEmbeddings(tf.keras.layers.Layer):
|
||||||
"""
|
"""
|
||||||
Image to Patch Embedding.
|
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
|
||||||
|
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
|
||||||
|
Transformer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config: ViTConfig, **kwargs):
|
def __init__(self, config: ViTConfig, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
image_size = to_2tuple(config.image_size)
|
image_size, patch_size = config.image_size, config.patch_size
|
||||||
patch_size = to_2tuple(config.patch_size)
|
num_channels, hidden_size = config.num_channels, config.hidden_size
|
||||||
|
|
||||||
|
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
|
||||||
|
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
|
||||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||||
self.image_size = image_size
|
self.image_size = image_size
|
||||||
self.patch_size = patch_size
|
self.patch_size = patch_size
|
||||||
self.num_patches = num_patches
|
self.num_patches = num_patches
|
||||||
self.num_channels = config.num_channels
|
self.num_channels = num_channels
|
||||||
self.embed_dim = config.hidden_size
|
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
self.projection = tf.keras.layers.Conv2D(
|
self.projection = tf.keras.layers.Conv2D(
|
||||||
filters=self.embed_dim,
|
filters=hidden_size,
|
||||||
kernel_size=patch_size,
|
kernel_size=patch_size,
|
||||||
strides=self.patch_size,
|
strides=patch_size,
|
||||||
padding="valid",
|
padding="valid",
|
||||||
data_format="channels_last",
|
data_format="channels_last",
|
||||||
use_bias=True,
|
use_bias=True,
|
||||||
@@ -183,8 +176,12 @@ class TFPatchEmbeddings(tf.keras.layers.Layer):
|
|||||||
self, pixel_values: tf.Tensor, interpolate_pos_encoding: bool = False, training: bool = False
|
self, pixel_values: tf.Tensor, interpolate_pos_encoding: bool = False, training: bool = False
|
||||||
) -> tf.Tensor:
|
) -> tf.Tensor:
|
||||||
batch_size, num_channels, height, width = shape_list(pixel_values)
|
batch_size, num_channels, height, width = shape_list(pixel_values)
|
||||||
|
if tf.executing_eagerly() and num_channels != self.num_channels:
|
||||||
|
raise ValueError(
|
||||||
|
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
||||||
|
)
|
||||||
if not interpolate_pos_encoding:
|
if not interpolate_pos_encoding:
|
||||||
if getattr(height, "numpy", None) and getattr(width, "numpy", None):
|
if tf.executing_eagerly():
|
||||||
if height != self.image_size[0] or width != self.image_size[1]:
|
if height != self.image_size[0] or width != self.image_size[1]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Input image size ({height}*{width}) doesn't match model"
|
f"Input image size ({height}*{width}) doesn't match model"
|
||||||
@@ -201,9 +198,9 @@ class TFPatchEmbeddings(tf.keras.layers.Layer):
|
|||||||
# Change the 2D spatial dimensions to a single temporal dimension.
|
# Change the 2D spatial dimensions to a single temporal dimension.
|
||||||
# shape = (batch_size, num_patches, out_channels=embed_dim)
|
# shape = (batch_size, num_patches, out_channels=embed_dim)
|
||||||
num_patches = (width // self.patch_size[1]) * (height // self.patch_size[0])
|
num_patches = (width // self.patch_size[1]) * (height // self.patch_size[0])
|
||||||
x = tf.reshape(tensor=projection, shape=(batch_size, num_patches, -1))
|
embeddings = tf.reshape(tensor=projection, shape=(batch_size, num_patches, -1))
|
||||||
|
|
||||||
return x
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
class TFViTSelfAttention(tf.keras.layers.Layer):
|
class TFViTSelfAttention(tf.keras.layers.Layer):
|
||||||
|
|||||||
@@ -59,23 +59,9 @@ VIT_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
# Inspired by
|
|
||||||
# https://github.com/rwightman/pytorch-image-models/blob/b9bd960a032c75ca6b808ddeed76bee5f3ed4972/timm/models/layers/helpers.py
|
|
||||||
# From PyTorch internals
|
|
||||||
def to_2tuple(x):
|
|
||||||
if isinstance(x, collections.abc.Iterable):
|
|
||||||
return x
|
|
||||||
return (x, x)
|
|
||||||
|
|
||||||
|
|
||||||
# Based on timm implementation, which can be found here:
|
|
||||||
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
|
||||||
|
|
||||||
|
|
||||||
class ViTEmbeddings(nn.Module):
|
class ViTEmbeddings(nn.Module):
|
||||||
"""
|
"""
|
||||||
Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
|
Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config: ViTConfig, use_mask_token: bool = False) -> None:
|
def __init__(self, config: ViTConfig, use_mask_token: bool = False) -> None:
|
||||||
@@ -83,12 +69,7 @@ class ViTEmbeddings(nn.Module):
|
|||||||
|
|
||||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
|
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
|
||||||
self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
|
self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
|
||||||
self.patch_embeddings = PatchEmbeddings(
|
self.patch_embeddings = ViTPatchEmbeddings(config)
|
||||||
image_size=config.image_size,
|
|
||||||
patch_size=config.patch_size,
|
|
||||||
num_channels=config.num_channels,
|
|
||||||
embed_dim=config.hidden_size,
|
|
||||||
)
|
|
||||||
num_patches = self.patch_embeddings.num_patches
|
num_patches = self.patch_embeddings.num_patches
|
||||||
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
|
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
@@ -103,9 +84,9 @@ class ViTEmbeddings(nn.Module):
|
|||||||
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
||||||
"""
|
"""
|
||||||
|
|
||||||
npatch = embeddings.shape[1] - 1
|
num_patches = embeddings.shape[1] - 1
|
||||||
N = self.position_embeddings.shape[1] - 1
|
num_positions = self.position_embeddings.shape[1] - 1
|
||||||
if npatch == N and height == width:
|
if num_patches == num_positions and height == width:
|
||||||
return self.position_embeddings
|
return self.position_embeddings
|
||||||
class_pos_embed = self.position_embeddings[:, 0]
|
class_pos_embed = self.position_embeddings[:, 0]
|
||||||
patch_pos_embed = self.position_embeddings[:, 1:]
|
patch_pos_embed = self.position_embeddings[:, 1:]
|
||||||
@@ -115,9 +96,11 @@ class ViTEmbeddings(nn.Module):
|
|||||||
# we add a small number to avoid floating point error in the interpolation
|
# we add a small number to avoid floating point error in the interpolation
|
||||||
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
||||||
h0, w0 = h0 + 0.1, w0 + 0.1
|
h0, w0 = h0 + 0.1, w0 + 0.1
|
||||||
|
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
|
||||||
|
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
||||||
patch_pos_embed = nn.functional.interpolate(
|
patch_pos_embed = nn.functional.interpolate(
|
||||||
patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
|
patch_pos_embed,
|
||||||
scale_factor=(h0 / math.sqrt(N), w0 / math.sqrt(N)),
|
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
|
||||||
mode="bicubic",
|
mode="bicubic",
|
||||||
align_corners=False,
|
align_corners=False,
|
||||||
)
|
)
|
||||||
@@ -134,9 +117,9 @@ class ViTEmbeddings(nn.Module):
|
|||||||
batch_size, num_channels, height, width = pixel_values.shape
|
batch_size, num_channels, height, width = pixel_values.shape
|
||||||
embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
|
embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
|
||||||
|
|
||||||
batch_size, seq_len, _ = embeddings.size()
|
|
||||||
if bool_masked_pos is not None:
|
if bool_masked_pos is not None:
|
||||||
mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
|
seq_length = embeddings.shape[1]
|
||||||
|
mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
|
||||||
# replace the masked visual tokens by mask_tokens
|
# replace the masked visual tokens by mask_tokens
|
||||||
mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
|
mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
|
||||||
embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
|
embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
|
||||||
@@ -156,41 +139,42 @@ class ViTEmbeddings(nn.Module):
|
|||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
# Based on timm implementation, which can be found here:
|
class ViTPatchEmbeddings(nn.Module):
|
||||||
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
|
||||||
class PatchEmbeddings(nn.Module):
|
|
||||||
"""
|
"""
|
||||||
Image to Patch Embedding.
|
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
|
||||||
|
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
|
||||||
|
Transformer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, config):
|
||||||
self,
|
|
||||||
image_size: int = 224,
|
|
||||||
patch_size: Union[int, Tuple[int, int]] = 16,
|
|
||||||
num_channels: int = 3,
|
|
||||||
embed_dim: int = 768,
|
|
||||||
):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
image_size = to_2tuple(image_size)
|
image_size, patch_size = config.image_size, config.patch_size
|
||||||
patch_size = to_2tuple(patch_size)
|
num_channels, hidden_size = config.num_channels, config.hidden_size
|
||||||
|
|
||||||
|
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
|
||||||
|
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
|
||||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||||
self.image_size = image_size
|
self.image_size = image_size
|
||||||
self.patch_size = patch_size
|
self.patch_size = patch_size
|
||||||
|
self.num_channels = num_channels
|
||||||
self.num_patches = num_patches
|
self.num_patches = num_patches
|
||||||
|
|
||||||
self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
|
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
|
||||||
|
|
||||||
def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
|
def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
|
||||||
batch_size, num_channels, height, width = pixel_values.shape
|
batch_size, num_channels, height, width = pixel_values.shape
|
||||||
|
if num_channels != self.num_channels:
|
||||||
|
raise ValueError(
|
||||||
|
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
||||||
|
)
|
||||||
if not interpolate_pos_encoding:
|
if not interpolate_pos_encoding:
|
||||||
if height != self.image_size[0] or width != self.image_size[1]:
|
if height != self.image_size[0] or width != self.image_size[1]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Input image size ({height}*{width}) doesn't match model"
|
f"Input image size ({height}*{width}) doesn't match model"
|
||||||
f" ({self.image_size[0]}*{self.image_size[1]})."
|
f" ({self.image_size[0]}*{self.image_size[1]})."
|
||||||
)
|
)
|
||||||
x = self.projection(pixel_values).flatten(2).transpose(1, 2)
|
embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
|
||||||
return x
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
class ViTSelfAttention(nn.Module):
|
class ViTSelfAttention(nn.Module):
|
||||||
@@ -524,7 +508,7 @@ class ViTModel(ViTPreTrainedModel):
|
|||||||
# Initialize weights and apply final processing
|
# Initialize weights and apply final processing
|
||||||
self.post_init()
|
self.post_init()
|
||||||
|
|
||||||
def get_input_embeddings(self) -> PatchEmbeddings:
|
def get_input_embeddings(self) -> ViTPatchEmbeddings:
|
||||||
return self.embeddings.patch_embeddings
|
return self.embeddings.patch_embeddings
|
||||||
|
|
||||||
def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
|
def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
|
||||||
@@ -613,8 +597,8 @@ class ViTPooler(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
"ViT Model with a decoder on top for masked image modeling, as proposed in `SimMIM"
|
"ViT Model with a decoder on top for masked image modeling, as proposed in"
|
||||||
" <https://arxiv.org/abs/2111.09886>`__.",
|
" [SimMIM](https://arxiv.org/abs/2111.09886).",
|
||||||
VIT_START_DOCSTRING,
|
VIT_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
class ViTForMaskedImageModeling(ViTPreTrainedModel):
|
class ViTForMaskedImageModeling(ViTPreTrainedModel):
|
||||||
@@ -624,7 +608,11 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel):
|
|||||||
self.vit = ViTModel(config, add_pooling_layer=False, use_mask_token=True)
|
self.vit = ViTModel(config, add_pooling_layer=False, use_mask_token=True)
|
||||||
|
|
||||||
self.decoder = nn.Sequential(
|
self.decoder = nn.Sequential(
|
||||||
nn.Conv2d(in_channels=config.hidden_size, out_channels=config.encoder_stride**2 * 3, kernel_size=1),
|
nn.Conv2d(
|
||||||
|
in_channels=config.hidden_size,
|
||||||
|
out_channels=config.encoder_stride**2 * config.num_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
),
|
||||||
nn.PixelShuffle(config.encoder_stride),
|
nn.PixelShuffle(config.encoder_stride),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -133,13 +133,6 @@ class TFViTMAEForPreTrainingOutput(ModelOutput):
|
|||||||
attentions: Optional[Tuple[tf.Tensor]] = None
|
attentions: Optional[Tuple[tf.Tensor]] = None
|
||||||
|
|
||||||
|
|
||||||
# copied from transformers.models.vit.modeling_tf_vit.to_2tuple
|
|
||||||
def to_2tuple(x):
|
|
||||||
if isinstance(x, collections.abc.Iterable):
|
|
||||||
return x
|
|
||||||
return (x, x)
|
|
||||||
|
|
||||||
|
|
||||||
def get_2d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
|
def get_2d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
|
||||||
"""
|
"""
|
||||||
Create 2D sin/cos positional embeddings.
|
Create 2D sin/cos positional embeddings.
|
||||||
@@ -212,7 +205,7 @@ class TFViTMAEEmbeddings(tf.keras.layers.Layer):
|
|||||||
def __init__(self, config: ViTMAEConfig, **kwargs):
|
def __init__(self, config: ViTMAEConfig, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
self.patch_embeddings = TFPatchEmbeddings(config, name="patch_embeddings")
|
self.patch_embeddings = TFViTMAEPatchEmbeddings(config, name="patch_embeddings")
|
||||||
self.num_patches = self.patch_embeddings.num_patches
|
self.num_patches = self.patch_embeddings.num_patches
|
||||||
|
|
||||||
self.config = config
|
self.config = config
|
||||||
@@ -297,30 +290,30 @@ class TFViTMAEEmbeddings(tf.keras.layers.Layer):
|
|||||||
return embeddings, mask, ids_restore
|
return embeddings, mask, ids_restore
|
||||||
|
|
||||||
|
|
||||||
# Based on timm implementation, which can be found here:
|
class TFViTMAEPatchEmbeddings(tf.keras.layers.Layer):
|
||||||
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
|
||||||
class TFPatchEmbeddings(tf.keras.layers.Layer):
|
|
||||||
"""
|
"""
|
||||||
Image to Patch Embedding.
|
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
|
||||||
|
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
|
||||||
|
Transformer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config: ViTMAEConfig, **kwargs):
|
def __init__(self, config: ViTMAEConfig, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
image_size = to_2tuple(config.image_size)
|
image_size, patch_size = config.image_size, config.patch_size
|
||||||
patch_size = to_2tuple(config.patch_size)
|
num_channels, hidden_size = config.num_channels, config.hidden_size
|
||||||
|
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
|
||||||
|
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
|
||||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||||
self.image_size = image_size
|
self.image_size = image_size
|
||||||
self.patch_size = patch_size
|
self.patch_size = patch_size
|
||||||
self.num_patches = num_patches
|
self.num_patches = num_patches
|
||||||
self.num_channels = config.num_channels
|
self.num_channels = num_channels
|
||||||
self.embed_dim = config.hidden_size
|
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
self.projection = tf.keras.layers.Conv2D(
|
self.projection = tf.keras.layers.Conv2D(
|
||||||
filters=self.embed_dim,
|
filters=hidden_size,
|
||||||
kernel_size=self.patch_size,
|
kernel_size=patch_size,
|
||||||
strides=self.patch_size,
|
strides=patch_size,
|
||||||
padding="valid",
|
padding="valid",
|
||||||
data_format="channels_last",
|
data_format="channels_last",
|
||||||
kernel_initializer="glorot_uniform", # following torch.nn.Linear
|
kernel_initializer="glorot_uniform", # following torch.nn.Linear
|
||||||
@@ -330,7 +323,12 @@ class TFPatchEmbeddings(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor:
|
def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor:
|
||||||
batch_size, num_channels, height, width = shape_list(pixel_values)
|
batch_size, num_channels, height, width = shape_list(pixel_values)
|
||||||
if getattr(height, "numpy", None) and getattr(width, "numpy", None):
|
if tf.executing_eagerly():
|
||||||
|
if num_channels != self.num_channels:
|
||||||
|
raise ValueError(
|
||||||
|
"Make sure that the channel dimension of the pixel values match with the one set in the"
|
||||||
|
" configuration."
|
||||||
|
)
|
||||||
if height != self.image_size[0] or width != self.image_size[1]:
|
if height != self.image_size[0] or width != self.image_size[1]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Input image size ({height}*{width}) doesn't match model"
|
f"Input image size ({height}*{width}) doesn't match model"
|
||||||
|
|||||||
@@ -135,13 +135,6 @@ class ViTMAEForPreTrainingOutput(ModelOutput):
|
|||||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
|
|
||||||
|
|
||||||
# copied from transformers.models.vit.modeling_vit.to_2tuple ViT->ViTMAE
|
|
||||||
def to_2tuple(x):
|
|
||||||
if isinstance(x, collections.abc.Iterable):
|
|
||||||
return x
|
|
||||||
return (x, x)
|
|
||||||
|
|
||||||
|
|
||||||
def get_2d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
|
def get_2d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
|
||||||
"""
|
"""
|
||||||
Create 2D sin/cos positional embeddings.
|
Create 2D sin/cos positional embeddings.
|
||||||
@@ -213,12 +206,7 @@ class ViTMAEEmbeddings(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
|
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
|
||||||
self.patch_embeddings = PatchEmbeddings(
|
self.patch_embeddings = ViTMAEPatchEmbeddings(config)
|
||||||
image_size=config.image_size,
|
|
||||||
patch_size=config.patch_size,
|
|
||||||
num_channels=config.num_channels,
|
|
||||||
embed_dim=config.hidden_size,
|
|
||||||
)
|
|
||||||
self.num_patches = self.patch_embeddings.num_patches
|
self.num_patches = self.patch_embeddings.num_patches
|
||||||
# fixed sin-cos embedding
|
# fixed sin-cos embedding
|
||||||
self.position_embeddings = nn.Parameter(
|
self.position_embeddings = nn.Parameter(
|
||||||
@@ -291,27 +279,33 @@ class ViTMAEEmbeddings(nn.Module):
|
|||||||
return embeddings, mask, ids_restore
|
return embeddings, mask, ids_restore
|
||||||
|
|
||||||
|
|
||||||
# Based on timm implementation, which can be found here:
|
class ViTMAEPatchEmbeddings(nn.Module):
|
||||||
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
|
||||||
class PatchEmbeddings(nn.Module):
|
|
||||||
"""
|
"""
|
||||||
Image to Patch Embedding.
|
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
|
||||||
|
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
|
||||||
|
Transformer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, image_size=224, patch_size=16, num_channels=3, embed_dim=768):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
image_size = to_2tuple(image_size)
|
image_size, patch_size = config.image_size, config.patch_size
|
||||||
patch_size = to_2tuple(patch_size)
|
num_channels, hidden_size = config.num_channels, config.hidden_size
|
||||||
|
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
|
||||||
|
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
|
||||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||||
self.image_size = image_size
|
self.image_size = image_size
|
||||||
self.patch_size = patch_size
|
self.patch_size = patch_size
|
||||||
|
self.num_channels = num_channels
|
||||||
self.num_patches = num_patches
|
self.num_patches = num_patches
|
||||||
|
|
||||||
self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
|
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
|
||||||
|
|
||||||
def forward(self, pixel_values):
|
def forward(self, pixel_values):
|
||||||
batch_size, num_channels, height, width = pixel_values.shape
|
batch_size, num_channels, height, width = pixel_values.shape
|
||||||
|
if num_channels != self.num_channels:
|
||||||
|
raise ValueError(
|
||||||
|
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
||||||
|
)
|
||||||
if height != self.image_size[0] or width != self.image_size[1]:
|
if height != self.image_size[0] or width != self.image_size[1]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
|
f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
|
||||||
|
|||||||
@@ -111,13 +111,6 @@ class YolosObjectDetectionOutput(ModelOutput):
|
|||||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.vit.modeling_vit.to_2tuple
|
|
||||||
def to_2tuple(x):
|
|
||||||
if isinstance(x, collections.abc.Iterable):
|
|
||||||
return x
|
|
||||||
return (x, x)
|
|
||||||
|
|
||||||
|
|
||||||
class YolosEmbeddings(nn.Module):
|
class YolosEmbeddings(nn.Module):
|
||||||
"""
|
"""
|
||||||
Construct the CLS token, detection tokens, position and patch embeddings.
|
Construct the CLS token, detection tokens, position and patch embeddings.
|
||||||
@@ -129,12 +122,7 @@ class YolosEmbeddings(nn.Module):
|
|||||||
|
|
||||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
|
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
|
||||||
self.detection_tokens = nn.Parameter(torch.zeros(1, config.num_detection_tokens, config.hidden_size))
|
self.detection_tokens = nn.Parameter(torch.zeros(1, config.num_detection_tokens, config.hidden_size))
|
||||||
self.patch_embeddings = PatchEmbeddings(
|
self.patch_embeddings = YolosPatchEmbeddings(config)
|
||||||
image_size=config.image_size,
|
|
||||||
patch_size=config.patch_size,
|
|
||||||
num_channels=config.num_channels,
|
|
||||||
embed_dim=config.hidden_size,
|
|
||||||
)
|
|
||||||
num_patches = self.patch_embeddings.num_patches
|
num_patches = self.patch_embeddings.num_patches
|
||||||
self.position_embeddings = nn.Parameter(
|
self.position_embeddings = nn.Parameter(
|
||||||
torch.zeros(1, num_patches + config.num_detection_tokens + 1, config.hidden_size)
|
torch.zeros(1, num_patches + config.num_detection_tokens + 1, config.hidden_size)
|
||||||
@@ -228,32 +216,35 @@ class InterpolateMidPositionEmbeddings(nn.Module):
|
|||||||
return scale_pos_embed
|
return scale_pos_embed
|
||||||
|
|
||||||
|
|
||||||
# Based on timm implementation, which can be found here:
|
class YolosPatchEmbeddings(nn.Module):
|
||||||
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
|
||||||
class PatchEmbeddings(nn.Module):
|
|
||||||
"""
|
"""
|
||||||
Image to Patch Embedding.
|
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
|
||||||
|
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
|
||||||
|
Transformer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, config):
|
||||||
self,
|
|
||||||
image_size: int = 224,
|
|
||||||
patch_size: Union[int, Tuple[int, int]] = 16,
|
|
||||||
num_channels: int = 3,
|
|
||||||
embed_dim: int = 768,
|
|
||||||
):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
image_size = to_2tuple(image_size)
|
image_size, patch_size = config.image_size, config.patch_size
|
||||||
patch_size = to_2tuple(patch_size)
|
num_channels, hidden_size = config.num_channels, config.hidden_size
|
||||||
|
|
||||||
|
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
|
||||||
|
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
|
||||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||||
self.image_size = image_size
|
self.image_size = image_size
|
||||||
self.patch_size = patch_size
|
self.patch_size = patch_size
|
||||||
|
self.num_channels = num_channels
|
||||||
self.num_patches = num_patches
|
self.num_patches = num_patches
|
||||||
|
|
||||||
self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
|
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
|
||||||
|
|
||||||
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
||||||
|
batch_size, num_channels, height, width = pixel_values.shape
|
||||||
|
if num_channels != self.num_channels:
|
||||||
|
raise ValueError(
|
||||||
|
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
||||||
|
)
|
||||||
|
|
||||||
embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
|
embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
@@ -620,7 +611,7 @@ class YolosModel(YolosPreTrainedModel):
|
|||||||
# Initialize weights and apply final processing
|
# Initialize weights and apply final processing
|
||||||
self.post_init()
|
self.post_init()
|
||||||
|
|
||||||
def get_input_embeddings(self) -> PatchEmbeddings:
|
def get_input_embeddings(self) -> YolosPatchEmbeddings:
|
||||||
return self.embeddings.patch_embeddings
|
return self.embeddings.patch_embeddings
|
||||||
|
|
||||||
def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
|
def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
|
||||||
|
|||||||
@@ -12,6 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import collections
|
||||||
import contextlib
|
import contextlib
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
@@ -1534,3 +1535,9 @@ def check_json_file_has_correct_format(file_path):
|
|||||||
left_indent = len(lines[1]) - len(lines[1].lstrip())
|
left_indent = len(lines[1]) - len(lines[1].lstrip())
|
||||||
assert left_indent == 2
|
assert left_indent == 2
|
||||||
assert lines[-1].strip() == "}"
|
assert lines[-1].strip() == "}"
|
||||||
|
|
||||||
|
|
||||||
|
def to_2tuple(x):
|
||||||
|
if isinstance(x, collections.abc.Iterable):
|
||||||
|
return x
|
||||||
|
return (x, x)
|
||||||
|
|||||||
@@ -153,6 +153,16 @@ class BeitModelTester:
|
|||||||
result = model(pixel_values, labels=labels)
|
result = model(pixel_values, labels=labels)
|
||||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
|
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
|
||||||
|
|
||||||
|
# test greyscale images
|
||||||
|
config.num_channels = 1
|
||||||
|
model = BeitForImageClassification(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
|
||||||
|
result = model(pixel_values, labels=labels)
|
||||||
|
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
|
||||||
|
|
||||||
def create_and_check_for_semantic_segmentation(self, config, pixel_values, labels, pixel_labels):
|
def create_and_check_for_semantic_segmentation(self, config, pixel_values, labels, pixel_labels):
|
||||||
config.num_labels = self.num_labels
|
config.num_labels = self.num_labels
|
||||||
model = BeitForSemanticSegmentation(config)
|
model = BeitForSemanticSegmentation(config)
|
||||||
|
|||||||
@@ -105,7 +105,6 @@ class FlaxBeitModelTester(unittest.TestCase):
|
|||||||
return config, pixel_values, labels
|
return config, pixel_values, labels
|
||||||
|
|
||||||
def create_and_check_model(self, config, pixel_values, labels):
|
def create_and_check_model(self, config, pixel_values, labels):
|
||||||
|
|
||||||
model = FlaxBeitModel(config=config)
|
model = FlaxBeitModel(config=config)
|
||||||
result = model(pixel_values)
|
result = model(pixel_values)
|
||||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||||
@@ -121,6 +120,13 @@ class FlaxBeitModelTester(unittest.TestCase):
|
|||||||
result = model(pixel_values)
|
result = model(pixel_values)
|
||||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
|
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
|
||||||
|
|
||||||
|
# test greyscale images
|
||||||
|
config.num_channels = 1
|
||||||
|
model = FlaxBeitForImageClassification(config)
|
||||||
|
|
||||||
|
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
|
||||||
|
result = model(pixel_values)
|
||||||
|
|
||||||
def prepare_config_and_inputs_for_common(self):
|
def prepare_config_and_inputs_for_common(self):
|
||||||
config_and_inputs = self.prepare_config_and_inputs()
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
(
|
(
|
||||||
|
|||||||
@@ -37,10 +37,7 @@ if is_torch_available():
|
|||||||
Data2VecVisionForSemanticSegmentation,
|
Data2VecVisionForSemanticSegmentation,
|
||||||
Data2VecVisionModel,
|
Data2VecVisionModel,
|
||||||
)
|
)
|
||||||
from transformers.models.data2vec.modeling_data2vec_vision import (
|
from transformers.models.data2vec.modeling_data2vec_vision import DATA2VEC_VISION_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||||
DATA2VEC_VISION_PRETRAINED_MODEL_ARCHIVE_LIST,
|
|
||||||
to_2tuple,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if is_vision_available():
|
if is_vision_available():
|
||||||
@@ -94,6 +91,10 @@ class Data2VecVisionModelTester:
|
|||||||
self.out_indices = out_indices
|
self.out_indices = out_indices
|
||||||
self.num_labels = num_labels
|
self.num_labels = num_labels
|
||||||
|
|
||||||
|
# in BeiT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
|
||||||
|
num_patches = (image_size // patch_size) ** 2
|
||||||
|
self.seq_length = num_patches + 1
|
||||||
|
|
||||||
def prepare_config_and_inputs(self):
|
def prepare_config_and_inputs(self):
|
||||||
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||||
|
|
||||||
@@ -131,9 +132,7 @@ class Data2VecVisionModelTester:
|
|||||||
model.eval()
|
model.eval()
|
||||||
result = model(pixel_values)
|
result = model(pixel_values)
|
||||||
# expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
|
# expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
|
||||||
image_size = to_2tuple(self.image_size)
|
num_patches = (self.image_size // self.patch_size) ** 2
|
||||||
patch_size = to_2tuple(self.patch_size)
|
|
||||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
|
||||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
|
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
|
||||||
|
|
||||||
def create_and_check_for_image_classification(self, config, pixel_values, labels, pixel_labels):
|
def create_and_check_for_image_classification(self, config, pixel_values, labels, pixel_labels):
|
||||||
@@ -286,109 +285,6 @@ class Data2VecVisionModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_attention_outputs(self):
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
|
||||||
config.return_dict = True
|
|
||||||
|
|
||||||
# in Data2VecVision, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
|
|
||||||
image_size = to_2tuple(self.model_tester.image_size)
|
|
||||||
patch_size = to_2tuple(self.model_tester.patch_size)
|
|
||||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
|
||||||
seq_len = num_patches + 1
|
|
||||||
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
|
|
||||||
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
|
|
||||||
chunk_length = getattr(self.model_tester, "chunk_length", None)
|
|
||||||
if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
|
|
||||||
encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
|
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
|
||||||
inputs_dict["output_attentions"] = True
|
|
||||||
inputs_dict["output_hidden_states"] = False
|
|
||||||
config.return_dict = True
|
|
||||||
model = model_class(config)
|
|
||||||
model.to(torch_device)
|
|
||||||
model.eval()
|
|
||||||
with torch.no_grad():
|
|
||||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
|
||||||
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
|
||||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
|
||||||
|
|
||||||
# check that output_attentions also work using config
|
|
||||||
del inputs_dict["output_attentions"]
|
|
||||||
config.output_attentions = True
|
|
||||||
model = model_class(config)
|
|
||||||
model.to(torch_device)
|
|
||||||
model.eval()
|
|
||||||
with torch.no_grad():
|
|
||||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
|
||||||
|
|
||||||
attentions = outputs.attentions
|
|
||||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
|
||||||
|
|
||||||
self.assertListEqual(
|
|
||||||
list(attentions[0].shape[-3:]),
|
|
||||||
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
|
|
||||||
)
|
|
||||||
out_len = len(outputs)
|
|
||||||
|
|
||||||
# Check attention is always last and order is fine
|
|
||||||
inputs_dict["output_attentions"] = True
|
|
||||||
inputs_dict["output_hidden_states"] = True
|
|
||||||
model = model_class(config)
|
|
||||||
model.to(torch_device)
|
|
||||||
model.eval()
|
|
||||||
with torch.no_grad():
|
|
||||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
|
||||||
|
|
||||||
self.assertEqual(out_len + 1, len(outputs))
|
|
||||||
|
|
||||||
self_attentions = outputs.attentions
|
|
||||||
|
|
||||||
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
|
|
||||||
self.assertListEqual(
|
|
||||||
list(self_attentions[0].shape[-3:]),
|
|
||||||
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_hidden_states_output(self):
|
|
||||||
def check_hidden_states_output(inputs_dict, config, model_class):
|
|
||||||
model = model_class(config)
|
|
||||||
model.to(torch_device)
|
|
||||||
model.eval()
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
|
||||||
|
|
||||||
hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
|
|
||||||
|
|
||||||
expected_num_layers = getattr(
|
|
||||||
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
|
|
||||||
)
|
|
||||||
self.assertEqual(len(hidden_states), expected_num_layers)
|
|
||||||
|
|
||||||
# Data2VecVision has a different seq_length
|
|
||||||
image_size = to_2tuple(self.model_tester.image_size)
|
|
||||||
patch_size = to_2tuple(self.model_tester.patch_size)
|
|
||||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
|
||||||
seq_length = num_patches + 1
|
|
||||||
|
|
||||||
self.assertListEqual(
|
|
||||||
list(hidden_states[0].shape[-2:]),
|
|
||||||
[seq_length, self.model_tester.hidden_size],
|
|
||||||
)
|
|
||||||
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
|
||||||
inputs_dict["output_hidden_states"] = True
|
|
||||||
check_hidden_states_output(inputs_dict, config, model_class)
|
|
||||||
|
|
||||||
# check that output_hidden_states also work using config
|
|
||||||
del inputs_dict["output_hidden_states"]
|
|
||||||
config.output_hidden_states = True
|
|
||||||
|
|
||||||
check_hidden_states_output(inputs_dict, config, model_class)
|
|
||||||
|
|
||||||
def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=2e-4, name="outputs", attributes=None):
|
def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=2e-4, name="outputs", attributes=None):
|
||||||
# We override with a slightly higher tol value, as semseg models tend to diverge a bit more
|
# We override with a slightly higher tol value, as semseg models tend to diverge a bit more
|
||||||
super().check_pt_tf_outputs(tf_outputs, pt_outputs, model_class, tol, name, attributes)
|
super().check_pt_tf_outputs(tf_outputs, pt_outputs, model_class, tol, name, attributes)
|
||||||
|
|||||||
@@ -131,6 +131,25 @@ class DeiTModelTester:
|
|||||||
result = model(pixel_values)
|
result = model(pixel_values)
|
||||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||||
|
|
||||||
|
def create_and_check_for_masked_image_modeling(self, config, pixel_values, labels):
|
||||||
|
model = DeiTForMaskedImageModeling(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
result = model(pixel_values)
|
||||||
|
self.parent.assertEqual(
|
||||||
|
result.logits.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size)
|
||||||
|
)
|
||||||
|
|
||||||
|
# test greyscale images
|
||||||
|
config.num_channels = 1
|
||||||
|
model = DeiTForMaskedImageModeling(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
|
||||||
|
result = model(pixel_values)
|
||||||
|
self.parent.assertEqual(result.logits.shape, (self.batch_size, 1, self.image_size, self.image_size))
|
||||||
|
|
||||||
def create_and_check_for_image_classification(self, config, pixel_values, labels):
|
def create_and_check_for_image_classification(self, config, pixel_values, labels):
|
||||||
config.num_labels = self.type_sequence_label_size
|
config.num_labels = self.type_sequence_label_size
|
||||||
model = DeiTForImageClassification(config)
|
model = DeiTForImageClassification(config)
|
||||||
@@ -139,6 +158,16 @@ class DeiTModelTester:
|
|||||||
result = model(pixel_values, labels=labels)
|
result = model(pixel_values, labels=labels)
|
||||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
|
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
|
||||||
|
|
||||||
|
# test greyscale images
|
||||||
|
config.num_channels = 1
|
||||||
|
model = DeiTForImageClassification(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
|
||||||
|
result = model(pixel_values, labels=labels)
|
||||||
|
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
|
||||||
|
|
||||||
def prepare_config_and_inputs_for_common(self):
|
def prepare_config_and_inputs_for_common(self):
|
||||||
config_and_inputs = self.prepare_config_and_inputs()
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
(
|
(
|
||||||
@@ -208,6 +237,10 @@ class DeiTModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_for_masked_image_modeling(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_for_masked_image_modeling(*config_and_inputs)
|
||||||
|
|
||||||
def test_for_image_classification(self):
|
def test_for_image_classification(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
|
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" Testing suite for the PyTorch Swin model. """
|
""" Testing suite for the PyTorch Swin model. """
|
||||||
|
|
||||||
|
import collections
|
||||||
import inspect
|
import inspect
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
@@ -33,7 +34,7 @@ if is_torch_available():
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from transformers import SwinForImageClassification, SwinForMaskedImageModeling, SwinModel
|
from transformers import SwinForImageClassification, SwinForMaskedImageModeling, SwinModel
|
||||||
from transformers.models.swin.modeling_swin import SWIN_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple
|
from transformers.models.swin.modeling_swin import SWIN_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||||
|
|
||||||
if is_vision_available():
|
if is_vision_available():
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
@@ -141,6 +142,25 @@ class SwinModelTester:
|
|||||||
|
|
||||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim))
|
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim))
|
||||||
|
|
||||||
|
def create_and_check_for_masked_image_modeling(self, config, pixel_values, labels):
|
||||||
|
model = SwinForMaskedImageModeling(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
result = model(pixel_values)
|
||||||
|
self.parent.assertEqual(
|
||||||
|
result.logits.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size)
|
||||||
|
)
|
||||||
|
|
||||||
|
# test greyscale images
|
||||||
|
config.num_channels = 1
|
||||||
|
model = SwinForMaskedImageModeling(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
|
||||||
|
result = model(pixel_values)
|
||||||
|
self.parent.assertEqual(result.logits.shape, (self.batch_size, 1, self.image_size, self.image_size))
|
||||||
|
|
||||||
def create_and_check_for_image_classification(self, config, pixel_values, labels):
|
def create_and_check_for_image_classification(self, config, pixel_values, labels):
|
||||||
config.num_labels = self.type_sequence_label_size
|
config.num_labels = self.type_sequence_label_size
|
||||||
model = SwinForImageClassification(config)
|
model = SwinForImageClassification(config)
|
||||||
@@ -149,6 +169,16 @@ class SwinModelTester:
|
|||||||
result = model(pixel_values, labels=labels)
|
result = model(pixel_values, labels=labels)
|
||||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
|
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
|
||||||
|
|
||||||
|
# test greyscale images
|
||||||
|
config.num_channels = 1
|
||||||
|
model = SwinForImageClassification(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
|
||||||
|
result = model(pixel_values)
|
||||||
|
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
|
||||||
|
|
||||||
def prepare_config_and_inputs_for_common(self):
|
def prepare_config_and_inputs_for_common(self):
|
||||||
config_and_inputs = self.prepare_config_and_inputs()
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
(
|
(
|
||||||
@@ -198,6 +228,14 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_for_masked_image_modeling(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_for_masked_image_modeling(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_for_image_classification(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
|
||||||
|
|
||||||
def test_inputs_embeds(self):
|
def test_inputs_embeds(self):
|
||||||
# Swin does not use inputs_embeds
|
# Swin does not use inputs_embeds
|
||||||
pass
|
pass
|
||||||
@@ -299,7 +337,11 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
self.assertEqual(len(hidden_states), expected_num_layers)
|
self.assertEqual(len(hidden_states), expected_num_layers)
|
||||||
|
|
||||||
# Swin has a different seq_length
|
# Swin has a different seq_length
|
||||||
patch_size = to_2tuple(config.patch_size)
|
patch_size = (
|
||||||
|
config.patch_size
|
||||||
|
if isinstance(config.patch_size, collections.abc.Iterable)
|
||||||
|
else (config.patch_size, config.patch_size)
|
||||||
|
)
|
||||||
|
|
||||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||||
|
|
||||||
@@ -323,7 +365,11 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
def test_hidden_states_output(self):
|
def test_hidden_states_output(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
image_size = to_2tuple(self.model_tester.image_size)
|
image_size = (
|
||||||
|
self.model_tester.image_size
|
||||||
|
if isinstance(self.model_tester.image_size, collections.abc.Iterable)
|
||||||
|
else (self.model_tester.image_size, self.model_tester.image_size)
|
||||||
|
)
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
inputs_dict["output_hidden_states"] = True
|
inputs_dict["output_hidden_states"] = True
|
||||||
@@ -339,8 +385,16 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
config.patch_size = 3
|
config.patch_size = 3
|
||||||
|
|
||||||
image_size = to_2tuple(self.model_tester.image_size)
|
image_size = (
|
||||||
patch_size = to_2tuple(config.patch_size)
|
self.model_tester.image_size
|
||||||
|
if isinstance(self.model_tester.image_size, collections.abc.Iterable)
|
||||||
|
else (self.model_tester.image_size, self.model_tester.image_size)
|
||||||
|
)
|
||||||
|
patch_size = (
|
||||||
|
config.patch_size
|
||||||
|
if isinstance(config.patch_size, collections.abc.Iterable)
|
||||||
|
else (config.patch_size, config.patch_size)
|
||||||
|
)
|
||||||
|
|
||||||
padded_height = image_size[0] + patch_size[0] - (image_size[0] % patch_size[0])
|
padded_height = image_size[0] + patch_size[0] - (image_size[0] % patch_size[0])
|
||||||
padded_width = image_size[1] + patch_size[1] - (image_size[1] % patch_size[1])
|
padded_width = image_size[1] + patch_size[1] - (image_size[1] % patch_size[1])
|
||||||
@@ -354,10 +408,6 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
config.output_hidden_states = True
|
config.output_hidden_states = True
|
||||||
self.check_hidden_states_output(inputs_dict, config, model_class, (padded_height, padded_width))
|
self.check_hidden_states_output(inputs_dict, config, model_class, (padded_height, padded_width))
|
||||||
|
|
||||||
def test_for_image_classification(self):
|
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
|
||||||
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
|
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
for model_name in SWIN_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
for model_name in SWIN_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ import unittest
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from transformers import SwinConfig
|
from transformers import SwinConfig
|
||||||
from transformers.testing_utils import require_tf, require_vision, slow
|
from transformers.testing_utils import require_tf, require_vision, slow, to_2tuple
|
||||||
from transformers.utils import cached_property, is_tf_available, is_vision_available
|
from transformers.utils import cached_property, is_tf_available, is_vision_available
|
||||||
|
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
@@ -36,7 +36,6 @@ if is_tf_available():
|
|||||||
TFSwinForImageClassification,
|
TFSwinForImageClassification,
|
||||||
TFSwinForMaskedImageModeling,
|
TFSwinForMaskedImageModeling,
|
||||||
TFSwinModel,
|
TFSwinModel,
|
||||||
to_2tuple,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -141,12 +140,34 @@ class TFSwinModelTester:
|
|||||||
|
|
||||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim))
|
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim))
|
||||||
|
|
||||||
|
def create_and_check_for_masked_image_modeling(self, config, pixel_values, labels):
|
||||||
|
model = TFSwinForMaskedImageModeling(config=config)
|
||||||
|
result = model(pixel_values)
|
||||||
|
self.parent.assertEqual(
|
||||||
|
result.logits.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size)
|
||||||
|
)
|
||||||
|
|
||||||
|
# test greyscale images
|
||||||
|
config.num_channels = 1
|
||||||
|
model = TFSwinForMaskedImageModeling(config)
|
||||||
|
|
||||||
|
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
|
||||||
|
result = model(pixel_values)
|
||||||
|
self.parent.assertEqual(result.logits.shape, (self.batch_size, 1, self.image_size, self.image_size))
|
||||||
|
|
||||||
def create_and_check_for_image_classification(self, config, pixel_values, labels):
|
def create_and_check_for_image_classification(self, config, pixel_values, labels):
|
||||||
config.num_labels = self.type_sequence_label_size
|
config.num_labels = self.type_sequence_label_size
|
||||||
model = TFSwinForImageClassification(config)
|
model = TFSwinForImageClassification(config)
|
||||||
result = model(pixel_values, labels=labels)
|
result = model(pixel_values, labels=labels)
|
||||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
|
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
|
||||||
|
|
||||||
|
# test greyscale images
|
||||||
|
config.num_channels = 1
|
||||||
|
model = TFSwinForImageClassification(config)
|
||||||
|
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
|
||||||
|
result = model(pixel_values)
|
||||||
|
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
|
||||||
|
|
||||||
def prepare_config_and_inputs_for_common(self):
|
def prepare_config_and_inputs_for_common(self):
|
||||||
config_and_inputs = self.prepare_config_and_inputs()
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
config, pixel_values, labels = config_and_inputs
|
config, pixel_values, labels = config_and_inputs
|
||||||
@@ -192,6 +213,14 @@ class TFSwinModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_for_masked_image_modeling(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_for_masked_image_modeling(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_for_image_classification(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
|
||||||
|
|
||||||
@unittest.skip(reason="Swin does not use inputs_embeds")
|
@unittest.skip(reason="Swin does not use inputs_embeds")
|
||||||
def test_inputs_embeds(self):
|
def test_inputs_embeds(self):
|
||||||
pass
|
pass
|
||||||
@@ -336,10 +365,6 @@ class TFSwinModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
config.output_hidden_states = True
|
config.output_hidden_states = True
|
||||||
self.check_hidden_states_output(inputs_dict, config, model_class, (padded_height, padded_width))
|
self.check_hidden_states_output(inputs_dict, config, model_class, (padded_height, padded_width))
|
||||||
|
|
||||||
def test_for_image_classification(self):
|
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
|
||||||
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
|
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
for model_name in TF_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
for model_name in TF_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||||
|
|||||||
@@ -91,8 +91,7 @@ class FlaxViTModelTester(unittest.TestCase):
|
|||||||
|
|
||||||
return config, pixel_values
|
return config, pixel_values
|
||||||
|
|
||||||
def create_and_check_model(self, config, pixel_values, labels):
|
def create_and_check_model(self, config, pixel_values):
|
||||||
|
|
||||||
model = FlaxViTModel(config=config)
|
model = FlaxViTModel(config=config)
|
||||||
result = model(pixel_values)
|
result = model(pixel_values)
|
||||||
# expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
|
# expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
|
||||||
@@ -101,6 +100,19 @@ class FlaxViTModelTester(unittest.TestCase):
|
|||||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
|
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
|
||||||
|
|
||||||
|
def create_and_check_for_image_classification(self, config, pixel_values):
|
||||||
|
config.num_labels = self.type_sequence_label_size
|
||||||
|
model = FlaxViTForImageClassification(config=config)
|
||||||
|
result = model(pixel_values)
|
||||||
|
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
|
||||||
|
|
||||||
|
# test greyscale images
|
||||||
|
config.num_channels = 1
|
||||||
|
model = FlaxViTForImageClassification(config)
|
||||||
|
|
||||||
|
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
|
||||||
|
result = model(pixel_values)
|
||||||
|
|
||||||
def prepare_config_and_inputs_for_common(self):
|
def prepare_config_and_inputs_for_common(self):
|
||||||
config_and_inputs = self.prepare_config_and_inputs()
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
(
|
(
|
||||||
@@ -123,7 +135,15 @@ class FlaxViTModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
|||||||
def test_config(self):
|
def test_config(self):
|
||||||
self.config_tester.run_common_tests()
|
self.config_tester.run_common_tests()
|
||||||
|
|
||||||
# We neeed to override this test because ViT's forward signature is different than text models.
|
def test_model(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_for_image_classification(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
|
||||||
|
|
||||||
|
# We need to override this test because ViT's forward signature is different than text models.
|
||||||
def test_forward_signature(self):
|
def test_forward_signature(self):
|
||||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
|||||||
@@ -133,6 +133,13 @@ class TFViTModelTester:
|
|||||||
result = model(pixel_values, interpolate_pos_encoding=True, training=False)
|
result = model(pixel_values, interpolate_pos_encoding=True, training=False)
|
||||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
|
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
|
||||||
|
|
||||||
|
# test greyscale images
|
||||||
|
config.num_channels = 1
|
||||||
|
model = TFViTForImageClassification(config)
|
||||||
|
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
|
||||||
|
result = model(pixel_values)
|
||||||
|
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
|
||||||
|
|
||||||
def prepare_config_and_inputs_for_common(self):
|
def prepare_config_and_inputs_for_common(self):
|
||||||
config_and_inputs = self.prepare_config_and_inputs()
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
config, pixel_values, labels = config_and_inputs
|
config, pixel_values, labels = config_and_inputs
|
||||||
|
|||||||
@@ -120,6 +120,25 @@ class ViTModelTester:
|
|||||||
result = model(pixel_values)
|
result = model(pixel_values)
|
||||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||||
|
|
||||||
|
def create_and_check_for_masked_image_modeling(self, config, pixel_values, labels):
|
||||||
|
model = ViTForMaskedImageModeling(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
result = model(pixel_values)
|
||||||
|
self.parent.assertEqual(
|
||||||
|
result.logits.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size)
|
||||||
|
)
|
||||||
|
|
||||||
|
# test greyscale images
|
||||||
|
config.num_channels = 1
|
||||||
|
model = ViTForMaskedImageModeling(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
|
||||||
|
result = model(pixel_values)
|
||||||
|
self.parent.assertEqual(result.logits.shape, (self.batch_size, 1, self.image_size, self.image_size))
|
||||||
|
|
||||||
def create_and_check_for_image_classification(self, config, pixel_values, labels):
|
def create_and_check_for_image_classification(self, config, pixel_values, labels):
|
||||||
config.num_labels = self.type_sequence_label_size
|
config.num_labels = self.type_sequence_label_size
|
||||||
model = ViTForImageClassification(config)
|
model = ViTForImageClassification(config)
|
||||||
@@ -128,6 +147,16 @@ class ViTModelTester:
|
|||||||
result = model(pixel_values, labels=labels)
|
result = model(pixel_values, labels=labels)
|
||||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
|
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
|
||||||
|
|
||||||
|
# test greyscale images
|
||||||
|
config.num_channels = 1
|
||||||
|
model = ViTForImageClassification(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
|
||||||
|
result = model(pixel_values)
|
||||||
|
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
|
||||||
|
|
||||||
def prepare_config_and_inputs_for_common(self):
|
def prepare_config_and_inputs_for_common(self):
|
||||||
config_and_inputs = self.prepare_config_and_inputs()
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
(
|
(
|
||||||
@@ -197,6 +226,10 @@ class ViTModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_for_masked_image_modeling(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_for_masked_image_modeling(*config_and_inputs)
|
||||||
|
|
||||||
def test_for_image_classification(self):
|
def test_for_image_classification(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
|
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
|
||||||
@@ -240,3 +273,30 @@ class ViTModelIntegrationTest(unittest.TestCase):
|
|||||||
expected_slice = torch.tensor([-0.2744, 0.8215, -0.0836]).to(torch_device)
|
expected_slice = torch.tensor([-0.2744, 0.8215, -0.0836]).to(torch_device)
|
||||||
|
|
||||||
self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
|
self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_inference_interpolate_pos_encoding(self):
|
||||||
|
# ViT models have an `interpolate_pos_encoding` argument in their forward method,
|
||||||
|
# allowing to interpolate the pre-trained position embeddings in order to use
|
||||||
|
# the model on higher resolutions. The DINO model by Facebook AI leverages this
|
||||||
|
# to visualize self-attention on higher resolution images.
|
||||||
|
model = ViTModel.from_pretrained("facebook/dino-vits8").to(torch_device)
|
||||||
|
|
||||||
|
feature_extractor = ViTFeatureExtractor.from_pretrained("facebook/dino-vits8", size=480)
|
||||||
|
image = prepare_img()
|
||||||
|
inputs = feature_extractor(images=image, return_tensors="pt")
|
||||||
|
pixel_values = inputs.pixel_values.to(torch_device)
|
||||||
|
|
||||||
|
# forward pass
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(pixel_values, interpolate_pos_encoding=True)
|
||||||
|
|
||||||
|
# verify the logits
|
||||||
|
expected_shape = torch.Size((1, 3601, 384))
|
||||||
|
self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
|
||||||
|
|
||||||
|
expected_slice = torch.tensor(
|
||||||
|
[[4.2340, 4.3906, -6.6692], [4.5463, 1.8928, -6.7257], [4.4429, 0.8496, -5.8585]]
|
||||||
|
).to(torch_device)
|
||||||
|
|
||||||
|
self.assertTrue(torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4))
|
||||||
|
|||||||
@@ -38,7 +38,6 @@ if is_tf_available():
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from transformers import TFViTMAEForPreTraining, TFViTMAEModel
|
from transformers import TFViTMAEForPreTraining, TFViTMAEModel
|
||||||
from transformers.models.vit_mae.modeling_tf_vit_mae import to_2tuple
|
|
||||||
|
|
||||||
|
|
||||||
if is_vision_available():
|
if is_vision_available():
|
||||||
@@ -67,6 +66,7 @@ class TFViTMAEModelTester:
|
|||||||
type_sequence_label_size=10,
|
type_sequence_label_size=10,
|
||||||
initializer_range=0.02,
|
initializer_range=0.02,
|
||||||
num_labels=3,
|
num_labels=3,
|
||||||
|
mask_ratio=0.6,
|
||||||
scope=None,
|
scope=None,
|
||||||
):
|
):
|
||||||
self.parent = parent
|
self.parent = parent
|
||||||
@@ -85,8 +85,14 @@ class TFViTMAEModelTester:
|
|||||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||||
self.type_sequence_label_size = type_sequence_label_size
|
self.type_sequence_label_size = type_sequence_label_size
|
||||||
self.initializer_range = initializer_range
|
self.initializer_range = initializer_range
|
||||||
|
self.mask_ratio = mask_ratio
|
||||||
self.scope = scope
|
self.scope = scope
|
||||||
|
|
||||||
|
# in ViTMAE, the expected sequence length = (num_patches + 1) * (1 - config.mask_ratio), rounded above
|
||||||
|
# (we add 1 for the [CLS] token)
|
||||||
|
num_patches = (image_size // patch_size) ** 2
|
||||||
|
self.seq_length = int(math.ceil((1 - mask_ratio) * (num_patches + 1)))
|
||||||
|
|
||||||
def prepare_config_and_inputs(self):
|
def prepare_config_and_inputs(self):
|
||||||
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||||
|
|
||||||
@@ -116,29 +122,21 @@ class TFViTMAEModelTester:
|
|||||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||||
is_decoder=False,
|
is_decoder=False,
|
||||||
initializer_range=self.initializer_range,
|
initializer_range=self.initializer_range,
|
||||||
|
mask_ratio=self.mask_ratio,
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_and_check_model(self, config, pixel_values, labels):
|
def create_and_check_model(self, config, pixel_values, labels):
|
||||||
model = TFViTMAEModel(config=config)
|
model = TFViTMAEModel(config=config)
|
||||||
result = model(pixel_values, training=False)
|
result = model(pixel_values, training=False)
|
||||||
# expected sequence length = (num_patches + 1) * (1 - config.mask_ratio), rounded above
|
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||||
# (we add 1 for the [CLS] token)
|
|
||||||
image_size = to_2tuple(self.image_size)
|
|
||||||
patch_size = to_2tuple(self.patch_size)
|
|
||||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
|
||||||
expected_seq_len = int(math.ceil((1 - config.mask_ratio) * (num_patches + 1)))
|
|
||||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, self.hidden_size))
|
|
||||||
|
|
||||||
def create_and_check_for_pretraining(self, config, pixel_values, labels):
|
def create_and_check_for_pretraining(self, config, pixel_values, labels):
|
||||||
model = TFViTMAEForPreTraining(config)
|
model = TFViTMAEForPreTraining(config)
|
||||||
result = model(pixel_values, training=False)
|
result = model(pixel_values, training=False)
|
||||||
# expected sequence length = num_patches
|
# expected sequence length = num_patches
|
||||||
image_size = to_2tuple(self.image_size)
|
num_patches = (self.image_size // self.patch_size) ** 2
|
||||||
patch_size = to_2tuple(self.patch_size)
|
|
||||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
|
||||||
expected_seq_len = num_patches
|
|
||||||
expected_num_channels = self.patch_size**2 * self.num_channels
|
expected_num_channels = self.patch_size**2 * self.num_channels
|
||||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, expected_seq_len, expected_num_channels))
|
self.parent.assertEqual(result.logits.shape, (self.batch_size, num_patches, expected_num_channels))
|
||||||
|
|
||||||
# test greyscale images
|
# test greyscale images
|
||||||
config.num_channels = 1
|
config.num_channels = 1
|
||||||
@@ -147,7 +145,7 @@ class TFViTMAEModelTester:
|
|||||||
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
|
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
|
||||||
result = model(pixel_values, training=False)
|
result = model(pixel_values, training=False)
|
||||||
expected_num_channels = self.patch_size**2
|
expected_num_channels = self.patch_size**2
|
||||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, expected_seq_len, expected_num_channels))
|
self.parent.assertEqual(result.logits.shape, (self.batch_size, num_patches, expected_num_channels))
|
||||||
|
|
||||||
def prepare_config_and_inputs_for_common(self):
|
def prepare_config_and_inputs_for_common(self):
|
||||||
config_and_inputs = self.prepare_config_and_inputs()
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
@@ -179,7 +177,6 @@ class TFViTMAEModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
@unittest.skip(reason="ViTMAE does not use inputs_embeds")
|
@unittest.skip(reason="ViTMAE does not use inputs_embeds")
|
||||||
def test_inputs_embeds(self):
|
def test_inputs_embeds(self):
|
||||||
# ViTMAE does not use inputs_embeds
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def test_model_common_attributes(self):
|
def test_model_common_attributes(self):
|
||||||
@@ -266,114 +263,6 @@ class TFViTMAEModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
output_for_kw_input = model(**inputs_np, noise=noise)
|
output_for_kw_input = model(**inputs_np, noise=noise)
|
||||||
self.assert_outputs_same(output_for_dict_input, output_for_kw_input)
|
self.assert_outputs_same(output_for_dict_input, output_for_kw_input)
|
||||||
|
|
||||||
def test_attention_outputs(self):
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
|
||||||
config.return_dict = True
|
|
||||||
|
|
||||||
# in ViTMAE, the seq_len equals (number of patches + 1) * (1 - mask_ratio), rounded above
|
|
||||||
image_size = to_2tuple(self.model_tester.image_size)
|
|
||||||
patch_size = to_2tuple(self.model_tester.patch_size)
|
|
||||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
|
||||||
seq_len = int(math.ceil((1 - config.mask_ratio) * (num_patches + 1)))
|
|
||||||
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
|
|
||||||
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
|
|
||||||
chunk_length = getattr(self.model_tester, "chunk_length", None)
|
|
||||||
if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
|
|
||||||
encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
|
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
|
||||||
inputs_dict["output_attentions"] = True
|
|
||||||
inputs_dict["output_hidden_states"] = False
|
|
||||||
config.return_dict = True
|
|
||||||
model = model_class(config)
|
|
||||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False)
|
|
||||||
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
|
||||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
|
||||||
|
|
||||||
# check that output_attentions also work using config
|
|
||||||
del inputs_dict["output_attentions"]
|
|
||||||
config.output_attentions = True
|
|
||||||
model = model_class(config)
|
|
||||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False)
|
|
||||||
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
|
||||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
|
||||||
|
|
||||||
if chunk_length is not None:
|
|
||||||
self.assertListEqual(
|
|
||||||
list(attentions[0].shape[-4:]),
|
|
||||||
[self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.assertListEqual(
|
|
||||||
list(attentions[0].shape[-3:]),
|
|
||||||
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
|
|
||||||
)
|
|
||||||
out_len = len(outputs)
|
|
||||||
|
|
||||||
# Check attention is always last and order is fine
|
|
||||||
inputs_dict["output_attentions"] = True
|
|
||||||
inputs_dict["output_hidden_states"] = True
|
|
||||||
model = model_class(config)
|
|
||||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False)
|
|
||||||
|
|
||||||
if hasattr(self.model_tester, "num_hidden_states_types"):
|
|
||||||
added_hidden_states = self.model_tester.num_hidden_states_types
|
|
||||||
elif self.is_encoder_decoder:
|
|
||||||
added_hidden_states = 2
|
|
||||||
else:
|
|
||||||
added_hidden_states = 1
|
|
||||||
self.assertEqual(out_len + added_hidden_states, len(outputs))
|
|
||||||
|
|
||||||
self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
|
||||||
|
|
||||||
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
|
|
||||||
if chunk_length is not None:
|
|
||||||
self.assertListEqual(
|
|
||||||
list(self_attentions[0].shape[-4:]),
|
|
||||||
[self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.assertListEqual(
|
|
||||||
list(self_attentions[0].shape[-3:]),
|
|
||||||
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_hidden_states_output(self):
|
|
||||||
def check_hidden_states_output(inputs_dict, config, model_class):
|
|
||||||
model = model_class(config)
|
|
||||||
|
|
||||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
|
||||||
|
|
||||||
hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
|
|
||||||
|
|
||||||
expected_num_layers = getattr(
|
|
||||||
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
|
|
||||||
)
|
|
||||||
self.assertEqual(len(hidden_states), expected_num_layers)
|
|
||||||
|
|
||||||
# ViTMAE has a different seq_length
|
|
||||||
image_size = to_2tuple(self.model_tester.image_size)
|
|
||||||
patch_size = to_2tuple(self.model_tester.patch_size)
|
|
||||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
|
||||||
seq_length = int(math.ceil((1 - config.mask_ratio) * (num_patches + 1)))
|
|
||||||
|
|
||||||
self.assertListEqual(
|
|
||||||
list(hidden_states[0].shape[-2:]),
|
|
||||||
[seq_length, self.model_tester.hidden_size],
|
|
||||||
)
|
|
||||||
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
|
||||||
inputs_dict["output_hidden_states"] = True
|
|
||||||
check_hidden_states_output(inputs_dict, config, model_class)
|
|
||||||
|
|
||||||
# check that output_hidden_states also work using config
|
|
||||||
del inputs_dict["output_hidden_states"]
|
|
||||||
config.output_hidden_states = True
|
|
||||||
|
|
||||||
check_hidden_states_output(inputs_dict, config, model_class)
|
|
||||||
|
|
||||||
# overwrite from common since TFViTMAEForPretraining has random masking, we need to fix the noise
|
# overwrite from common since TFViTMAEForPretraining has random masking, we need to fix the noise
|
||||||
# to generate masks during test
|
# to generate masks during test
|
||||||
def check_pt_tf_models(self, tf_model, pt_model, tf_inputs_dict):
|
def check_pt_tf_models(self, tf_model, pt_model, tf_inputs_dict):
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ if is_torch_available():
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from transformers import ViTMAEForPreTraining, ViTMAEModel
|
from transformers import ViTMAEForPreTraining, ViTMAEModel
|
||||||
from transformers.models.vit.modeling_vit import VIT_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple
|
from transformers.models.vit.modeling_vit import VIT_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||||
|
|
||||||
|
|
||||||
if is_vision_available():
|
if is_vision_available():
|
||||||
@@ -64,6 +64,7 @@ class ViTMAEModelTester:
|
|||||||
type_sequence_label_size=10,
|
type_sequence_label_size=10,
|
||||||
initializer_range=0.02,
|
initializer_range=0.02,
|
||||||
num_labels=3,
|
num_labels=3,
|
||||||
|
mask_ratio=0.6,
|
||||||
scope=None,
|
scope=None,
|
||||||
):
|
):
|
||||||
self.parent = parent
|
self.parent = parent
|
||||||
@@ -82,8 +83,14 @@ class ViTMAEModelTester:
|
|||||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||||
self.type_sequence_label_size = type_sequence_label_size
|
self.type_sequence_label_size = type_sequence_label_size
|
||||||
self.initializer_range = initializer_range
|
self.initializer_range = initializer_range
|
||||||
|
self.mask_ratio = mask_ratio
|
||||||
self.scope = scope
|
self.scope = scope
|
||||||
|
|
||||||
|
# in ViTMAE, the expected sequence length = (num_patches + 1) * (1 - config.mask_ratio), rounded above
|
||||||
|
# (we add 1 for the [CLS] token)
|
||||||
|
num_patches = (image_size // patch_size) ** 2
|
||||||
|
self.seq_length = int(math.ceil((1 - mask_ratio) * (num_patches + 1)))
|
||||||
|
|
||||||
def prepare_config_and_inputs(self):
|
def prepare_config_and_inputs(self):
|
||||||
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||||
|
|
||||||
@@ -109,6 +116,7 @@ class ViTMAEModelTester:
|
|||||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||||
is_decoder=False,
|
is_decoder=False,
|
||||||
initializer_range=self.initializer_range,
|
initializer_range=self.initializer_range,
|
||||||
|
mask_ratio=self.mask_ratio,
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_and_check_model(self, config, pixel_values, labels):
|
def create_and_check_model(self, config, pixel_values, labels):
|
||||||
@@ -116,26 +124,16 @@ class ViTMAEModelTester:
|
|||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
model.eval()
|
model.eval()
|
||||||
result = model(pixel_values)
|
result = model(pixel_values)
|
||||||
# expected sequence length = (num_patches + 1) * (1 - config.mask_ratio), rounded above
|
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||||
# (we add 1 for the [CLS] token)
|
|
||||||
image_size = to_2tuple(self.image_size)
|
|
||||||
patch_size = to_2tuple(self.patch_size)
|
|
||||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
|
||||||
expected_seq_len = int(math.ceil((1 - config.mask_ratio) * (num_patches + 1)))
|
|
||||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, self.hidden_size))
|
|
||||||
|
|
||||||
def create_and_check_for_pretraining(self, config, pixel_values, labels):
|
def create_and_check_for_pretraining(self, config, pixel_values, labels):
|
||||||
model = ViTMAEForPreTraining(config)
|
model = ViTMAEForPreTraining(config)
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
model.eval()
|
model.eval()
|
||||||
result = model(pixel_values)
|
result = model(pixel_values)
|
||||||
# expected sequence length = num_patches
|
num_patches = (self.image_size // self.patch_size) ** 2
|
||||||
image_size = to_2tuple(self.image_size)
|
|
||||||
patch_size = to_2tuple(self.patch_size)
|
|
||||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
|
||||||
expected_seq_len = num_patches
|
|
||||||
expected_num_channels = self.patch_size**2 * self.num_channels
|
expected_num_channels = self.patch_size**2 * self.num_channels
|
||||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, expected_seq_len, expected_num_channels))
|
self.parent.assertEqual(result.logits.shape, (self.batch_size, num_patches, expected_num_channels))
|
||||||
|
|
||||||
# test greyscale images
|
# test greyscale images
|
||||||
config.num_channels = 1
|
config.num_channels = 1
|
||||||
@@ -145,7 +143,7 @@ class ViTMAEModelTester:
|
|||||||
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
|
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
|
||||||
result = model(pixel_values)
|
result = model(pixel_values)
|
||||||
expected_num_channels = self.patch_size**2
|
expected_num_channels = self.patch_size**2
|
||||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, expected_seq_len, expected_num_channels))
|
self.parent.assertEqual(result.logits.shape, (self.batch_size, num_patches, expected_num_channels))
|
||||||
|
|
||||||
def prepare_config_and_inputs_for_common(self):
|
def prepare_config_and_inputs_for_common(self):
|
||||||
config_and_inputs = self.prepare_config_and_inputs()
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
@@ -175,8 +173,8 @@ class ViTMAEModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
def test_config(self):
|
def test_config(self):
|
||||||
self.config_tester.run_common_tests()
|
self.config_tester.run_common_tests()
|
||||||
|
|
||||||
|
@unittest.skip(reason="ViTMAE does not use inputs_embeds")
|
||||||
def test_inputs_embeds(self):
|
def test_inputs_embeds(self):
|
||||||
# ViTMAE does not use inputs_embeds
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def test_model_common_attributes(self):
|
def test_model_common_attributes(self):
|
||||||
@@ -208,126 +206,6 @@ class ViTMAEModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_for_pretraining(*config_and_inputs)
|
self.model_tester.create_and_check_for_pretraining(*config_and_inputs)
|
||||||
|
|
||||||
def test_attention_outputs(self):
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
|
||||||
config.return_dict = True
|
|
||||||
|
|
||||||
# in ViTMAE, the seq_len equals (number of patches + 1) * (1 - mask_ratio), rounded above
|
|
||||||
image_size = to_2tuple(self.model_tester.image_size)
|
|
||||||
patch_size = to_2tuple(self.model_tester.patch_size)
|
|
||||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
|
||||||
seq_len = int(math.ceil((1 - config.mask_ratio) * (num_patches + 1)))
|
|
||||||
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
|
|
||||||
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
|
|
||||||
chunk_length = getattr(self.model_tester, "chunk_length", None)
|
|
||||||
if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
|
|
||||||
encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
|
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
|
||||||
inputs_dict["output_attentions"] = True
|
|
||||||
inputs_dict["output_hidden_states"] = False
|
|
||||||
config.return_dict = True
|
|
||||||
model = model_class(config)
|
|
||||||
model.to(torch_device)
|
|
||||||
model.eval()
|
|
||||||
with torch.no_grad():
|
|
||||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
|
||||||
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
|
||||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
|
||||||
|
|
||||||
# check that output_attentions also work using config
|
|
||||||
del inputs_dict["output_attentions"]
|
|
||||||
config.output_attentions = True
|
|
||||||
model = model_class(config)
|
|
||||||
model.to(torch_device)
|
|
||||||
model.eval()
|
|
||||||
with torch.no_grad():
|
|
||||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
|
||||||
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
|
||||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
|
||||||
|
|
||||||
if chunk_length is not None:
|
|
||||||
self.assertListEqual(
|
|
||||||
list(attentions[0].shape[-4:]),
|
|
||||||
[self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.assertListEqual(
|
|
||||||
list(attentions[0].shape[-3:]),
|
|
||||||
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
|
|
||||||
)
|
|
||||||
out_len = len(outputs)
|
|
||||||
|
|
||||||
# Check attention is always last and order is fine
|
|
||||||
inputs_dict["output_attentions"] = True
|
|
||||||
inputs_dict["output_hidden_states"] = True
|
|
||||||
model = model_class(config)
|
|
||||||
model.to(torch_device)
|
|
||||||
model.eval()
|
|
||||||
with torch.no_grad():
|
|
||||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
|
||||||
|
|
||||||
if hasattr(self.model_tester, "num_hidden_states_types"):
|
|
||||||
added_hidden_states = self.model_tester.num_hidden_states_types
|
|
||||||
elif self.is_encoder_decoder:
|
|
||||||
added_hidden_states = 2
|
|
||||||
else:
|
|
||||||
added_hidden_states = 1
|
|
||||||
self.assertEqual(out_len + added_hidden_states, len(outputs))
|
|
||||||
|
|
||||||
self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
|
||||||
|
|
||||||
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
|
|
||||||
if chunk_length is not None:
|
|
||||||
self.assertListEqual(
|
|
||||||
list(self_attentions[0].shape[-4:]),
|
|
||||||
[self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.assertListEqual(
|
|
||||||
list(self_attentions[0].shape[-3:]),
|
|
||||||
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_hidden_states_output(self):
|
|
||||||
def check_hidden_states_output(inputs_dict, config, model_class):
|
|
||||||
model = model_class(config)
|
|
||||||
model.to(torch_device)
|
|
||||||
model.eval()
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
|
||||||
|
|
||||||
hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
|
|
||||||
|
|
||||||
expected_num_layers = getattr(
|
|
||||||
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
|
|
||||||
)
|
|
||||||
self.assertEqual(len(hidden_states), expected_num_layers)
|
|
||||||
|
|
||||||
# ViTMAE has a different seq_length
|
|
||||||
image_size = to_2tuple(self.model_tester.image_size)
|
|
||||||
patch_size = to_2tuple(self.model_tester.patch_size)
|
|
||||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
|
||||||
seq_length = int(math.ceil((1 - config.mask_ratio) * (num_patches + 1)))
|
|
||||||
|
|
||||||
self.assertListEqual(
|
|
||||||
list(hidden_states[0].shape[-2:]),
|
|
||||||
[seq_length, self.model_tester.hidden_size],
|
|
||||||
)
|
|
||||||
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
|
||||||
inputs_dict["output_hidden_states"] = True
|
|
||||||
check_hidden_states_output(inputs_dict, config, model_class)
|
|
||||||
|
|
||||||
# check that output_hidden_states also work using config
|
|
||||||
del inputs_dict["output_hidden_states"]
|
|
||||||
config.output_hidden_states = True
|
|
||||||
|
|
||||||
check_hidden_states_output(inputs_dict, config, model_class)
|
|
||||||
|
|
||||||
# overwrite from common since ViTMAEForPretraining has random masking, we need to fix the noise
|
# overwrite from common since ViTMAEForPretraining has random masking, we need to fix the noise
|
||||||
# to generate masks during test
|
# to generate masks during test
|
||||||
def check_pt_tf_models(self, tf_model, pt_model, pt_inputs_dict):
|
def check_pt_tf_models(self, tf_model, pt_model, pt_inputs_dict):
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ if is_torch_available():
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from transformers import YolosForObjectDetection, YolosModel
|
from transformers import YolosForObjectDetection, YolosModel
|
||||||
from transformers.models.yolos.modeling_yolos import YOLOS_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple
|
from transformers.models.yolos.modeling_yolos import YOLOS_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||||
|
|
||||||
|
|
||||||
if is_vision_available():
|
if is_vision_available():
|
||||||
@@ -86,9 +86,7 @@ class YolosModelTester:
|
|||||||
self.num_detection_tokens = num_detection_tokens
|
self.num_detection_tokens = num_detection_tokens
|
||||||
# we set the expected sequence length (which is used in several tests)
|
# we set the expected sequence length (which is used in several tests)
|
||||||
# expected sequence length = num_patches + 1 (we add 1 for the [CLS] token) + num_detection_tokens
|
# expected sequence length = num_patches + 1 (we add 1 for the [CLS] token) + num_detection_tokens
|
||||||
image_size = to_2tuple(self.image_size)
|
num_patches = (image_size[1] // patch_size) * (image_size[0] // patch_size)
|
||||||
patch_size = to_2tuple(self.patch_size)
|
|
||||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
|
||||||
self.expected_seq_len = num_patches + 1 + self.num_detection_tokens
|
self.expected_seq_len = num_patches + 1 + self.num_detection_tokens
|
||||||
|
|
||||||
def prepare_config_and_inputs(self):
|
def prepare_config_and_inputs(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user