Enable dynamic resolution input for Swin Transformer and variants (#30656)
* add interpolation of positional encoding support to swin
* add style changes
* use default image processor and make size a dictionary
  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* remove logits testing
  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Refactor image size validation logic when interpolation is disabled
  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* remove asserts in modeling
  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* add dynamic resolution input support to swinv2
* change size to ensure interpolation encoding path is triggered
* set interpolate_pos_encoding default value to False
  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* set interpolate_pos_encoding default value to False
  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* set interpolate_pos_encoding default value to False
  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* set interpolate_pos_encoding default value to False
  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* set interpolate_pos_encoding default value to False
  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* set interpolate_pos_encoding default value to False
  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* set interpolate_pos_encoding default value to False
  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* set interpolate_pos_encoding default value to False
* add dynamic resolution input to donut swin
* add dynamic resolution input to maskformer swin
---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -166,10 +166,48 @@ class DonutSwinEmbeddings(nn.Module):
|
|||||||
self.norm = nn.LayerNorm(config.embed_dim)
|
self.norm = nn.LayerNorm(config.embed_dim)
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
|
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
This method interpolates the pre-trained position encodings so that the model can be used on higher
|
||||||
|
resolution images.
|
||||||
|
|
||||||
|
Source:
|
||||||
|
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
||||||
|
"""
|
||||||
|
|
||||||
|
num_patches = embeddings.shape[1] - 1
|
||||||
|
num_positions = self.position_embeddings.shape[1] - 1
|
||||||
|
if num_patches == num_positions and height == width:
|
||||||
|
return self.position_embeddings
|
||||||
|
class_pos_embed = self.position_embeddings[:, 0]
|
||||||
|
patch_pos_embed = self.position_embeddings[:, 1:]
|
||||||
|
dim = embeddings.shape[-1]
|
||||||
|
h0 = height // self.config.patch_size
|
||||||
|
w0 = width // self.config.patch_size
|
||||||
|
# we add a small number to avoid floating point error in the interpolation
|
||||||
|
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
||||||
|
h0, w0 = h0 + 0.1, w0 + 0.1
|
||||||
|
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
|
||||||
|
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
||||||
|
patch_pos_embed = nn.functional.interpolate(
|
||||||
|
patch_pos_embed,
|
||||||
|
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
|
||||||
|
mode="bicubic",
|
||||||
|
align_corners=False,
|
||||||
|
)
|
||||||
|
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
||||||
|
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, pixel_values: Optional[torch.FloatTensor], bool_masked_pos: Optional[torch.BoolTensor] = None
|
self,
|
||||||
|
pixel_values: Optional[torch.FloatTensor],
|
||||||
|
bool_masked_pos: Optional[torch.BoolTensor] = None,
|
||||||
|
interpolate_pos_encoding: bool = False,
|
||||||
) -> Tuple[torch.Tensor]:
|
) -> Tuple[torch.Tensor]:
|
||||||
embeddings, output_dimensions = self.patch_embeddings(pixel_values)
|
_, num_channels, height, width = pixel_values.shape
|
||||||
|
embeddings, output_dimensions = self.patch_embeddings(
|
||||||
|
pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
|
||||||
|
)
|
||||||
embeddings = self.norm(embeddings)
|
embeddings = self.norm(embeddings)
|
||||||
batch_size, seq_len, _ = embeddings.size()
|
batch_size, seq_len, _ = embeddings.size()
|
||||||
|
|
||||||
@@ -180,6 +218,9 @@ class DonutSwinEmbeddings(nn.Module):
|
|||||||
embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
|
embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
|
||||||
|
|
||||||
if self.position_embeddings is not None:
|
if self.position_embeddings is not None:
|
||||||
|
if interpolate_pos_encoding:
|
||||||
|
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
|
||||||
|
else:
|
||||||
embeddings = embeddings + self.position_embeddings
|
embeddings = embeddings + self.position_embeddings
|
||||||
|
|
||||||
embeddings = self.dropout(embeddings)
|
embeddings = self.dropout(embeddings)
|
||||||
@@ -219,7 +260,9 @@ class DonutSwinPatchEmbeddings(nn.Module):
|
|||||||
pixel_values = nn.functional.pad(pixel_values, pad_values)
|
pixel_values = nn.functional.pad(pixel_values, pad_values)
|
||||||
return pixel_values
|
return pixel_values
|
||||||
|
|
||||||
def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
|
def forward(
|
||||||
|
self, pixel_values: Optional[torch.FloatTensor], interpolate_pos_encoding: bool = False
|
||||||
|
) -> Tuple[torch.Tensor, Tuple[int]]:
|
||||||
_, num_channels, height, width = pixel_values.shape
|
_, num_channels, height, width = pixel_values.shape
|
||||||
if num_channels != self.num_channels:
|
if num_channels != self.num_channels:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -227,6 +270,11 @@ class DonutSwinPatchEmbeddings(nn.Module):
|
|||||||
)
|
)
|
||||||
# pad the input to be divisible by self.patch_size, if needed
|
# pad the input to be divisible by self.patch_size, if needed
|
||||||
pixel_values = self.maybe_pad(pixel_values, height, width)
|
pixel_values = self.maybe_pad(pixel_values, height, width)
|
||||||
|
if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
|
||||||
|
raise ValueError(
|
||||||
|
f"Input image size ({height}*{width}) doesn't match model"
|
||||||
|
f" ({self.image_size[0]}*{self.image_size[1]})."
|
||||||
|
)
|
||||||
embeddings = self.projection(pixel_values)
|
embeddings = self.projection(pixel_values)
|
||||||
_, _, height, width = embeddings.shape
|
_, _, height, width = embeddings.shape
|
||||||
output_dimensions = (height, width)
|
output_dimensions = (height, width)
|
||||||
@@ -849,6 +897,8 @@ SWIN_INPUTS_DOCSTRING = r"""
|
|||||||
output_hidden_states (`bool`, *optional*):
|
output_hidden_states (`bool`, *optional*):
|
||||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||||
more detail.
|
more detail.
|
||||||
|
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to interpolate the pre-trained position encodings.
|
||||||
return_dict (`bool`, *optional*):
|
return_dict (`bool`, *optional*):
|
||||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||||
"""
|
"""
|
||||||
@@ -899,6 +949,7 @@ class DonutSwinModel(DonutSwinPreTrainedModel):
|
|||||||
head_mask: Optional[torch.FloatTensor] = None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
interpolate_pos_encoding: bool = False,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
) -> Union[Tuple, DonutSwinModelOutput]:
|
) -> Union[Tuple, DonutSwinModelOutput]:
|
||||||
r"""
|
r"""
|
||||||
@@ -921,7 +972,9 @@ class DonutSwinModel(DonutSwinPreTrainedModel):
|
|||||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||||
head_mask = self.get_head_mask(head_mask, len(self.config.depths))
|
head_mask = self.get_head_mask(head_mask, len(self.config.depths))
|
||||||
|
|
||||||
embedding_output, input_dimensions = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
|
embedding_output, input_dimensions = self.embeddings(
|
||||||
|
pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
|
||||||
|
)
|
||||||
|
|
||||||
encoder_outputs = self.encoder(
|
encoder_outputs = self.encoder(
|
||||||
embedding_output,
|
embedding_output,
|
||||||
|
|||||||
@@ -163,11 +163,49 @@ class MaskFormerSwinEmbeddings(nn.Module):
|
|||||||
self.norm = nn.LayerNorm(config.embed_dim)
|
self.norm = nn.LayerNorm(config.embed_dim)
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
def forward(self, pixel_values):
|
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||||
embeddings, output_dimensions = self.patch_embeddings(pixel_values)
|
"""
|
||||||
|
This method interpolates the pre-trained position encodings so that the model can be used on higher
|
||||||
|
resolution images.
|
||||||
|
|
||||||
|
Source:
|
||||||
|
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
||||||
|
"""
|
||||||
|
|
||||||
|
num_patches = embeddings.shape[1] - 1
|
||||||
|
num_positions = self.position_embeddings.shape[1] - 1
|
||||||
|
if num_patches == num_positions and height == width:
|
||||||
|
return self.position_embeddings
|
||||||
|
class_pos_embed = self.position_embeddings[:, 0]
|
||||||
|
patch_pos_embed = self.position_embeddings[:, 1:]
|
||||||
|
dim = embeddings.shape[-1]
|
||||||
|
h0 = height // self.config.patch_size
|
||||||
|
w0 = width // self.config.patch_size
|
||||||
|
# we add a small number to avoid floating point error in the interpolation
|
||||||
|
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
||||||
|
h0, w0 = h0 + 0.1, w0 + 0.1
|
||||||
|
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
|
||||||
|
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
||||||
|
patch_pos_embed = nn.functional.interpolate(
|
||||||
|
patch_pos_embed,
|
||||||
|
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
|
||||||
|
mode="bicubic",
|
||||||
|
align_corners=False,
|
||||||
|
)
|
||||||
|
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
||||||
|
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
||||||
|
|
||||||
|
def forward(self, pixel_values, interpolate_pos_encoding):
|
||||||
|
_, num_channels, height, width = pixel_values.shape
|
||||||
|
embeddings, output_dimensions = self.patch_embeddings(
|
||||||
|
pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
|
||||||
|
)
|
||||||
embeddings = self.norm(embeddings)
|
embeddings = self.norm(embeddings)
|
||||||
|
|
||||||
if self.position_embeddings is not None:
|
if self.position_embeddings is not None:
|
||||||
|
if interpolate_pos_encoding:
|
||||||
|
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
|
||||||
|
else:
|
||||||
embeddings = embeddings + self.position_embeddings
|
embeddings = embeddings + self.position_embeddings
|
||||||
|
|
||||||
embeddings = self.dropout(embeddings)
|
embeddings = self.dropout(embeddings)
|
||||||
@@ -207,7 +245,9 @@ class MaskFormerSwinPatchEmbeddings(nn.Module):
|
|||||||
pixel_values = nn.functional.pad(pixel_values, pad_values)
|
pixel_values = nn.functional.pad(pixel_values, pad_values)
|
||||||
return pixel_values
|
return pixel_values
|
||||||
|
|
||||||
def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
|
def forward(
|
||||||
|
self, pixel_values: Optional[torch.FloatTensor], interpolate_pos_encoding: bool = False
|
||||||
|
) -> Tuple[torch.Tensor, Tuple[int]]:
|
||||||
_, num_channels, height, width = pixel_values.shape
|
_, num_channels, height, width = pixel_values.shape
|
||||||
if num_channels != self.num_channels:
|
if num_channels != self.num_channels:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -215,6 +255,11 @@ class MaskFormerSwinPatchEmbeddings(nn.Module):
|
|||||||
)
|
)
|
||||||
# pad the input to be divisible by self.patch_size, if needed
|
# pad the input to be divisible by self.patch_size, if needed
|
||||||
pixel_values = self.maybe_pad(pixel_values, height, width)
|
pixel_values = self.maybe_pad(pixel_values, height, width)
|
||||||
|
if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
|
||||||
|
raise ValueError(
|
||||||
|
f"Input image size ({height}*{width}) doesn't match model"
|
||||||
|
f" ({self.image_size[0]}*{self.image_size[1]})."
|
||||||
|
)
|
||||||
embeddings = self.projection(pixel_values)
|
embeddings = self.projection(pixel_values)
|
||||||
_, _, height, width = embeddings.shape
|
_, _, height, width = embeddings.shape
|
||||||
output_dimensions = (height, width)
|
output_dimensions = (height, width)
|
||||||
@@ -780,6 +825,7 @@ class MaskFormerSwinModel(MaskFormerSwinPreTrainedModel):
|
|||||||
head_mask=None,
|
head_mask=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
|
interpolate_pos_encoding=False,
|
||||||
return_dict=None,
|
return_dict=None,
|
||||||
):
|
):
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
@@ -798,7 +844,9 @@ class MaskFormerSwinModel(MaskFormerSwinPreTrainedModel):
|
|||||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||||
head_mask = self.get_head_mask(head_mask, len(self.config.depths))
|
head_mask = self.get_head_mask(head_mask, len(self.config.depths))
|
||||||
|
|
||||||
embedding_output, input_dimensions = self.embeddings(pixel_values)
|
embedding_output, input_dimensions = self.embeddings(
|
||||||
|
pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
|
||||||
|
)
|
||||||
|
|
||||||
encoder_outputs = self.encoder(
|
encoder_outputs = self.encoder(
|
||||||
embedding_output,
|
embedding_output,
|
||||||
|
|||||||
@@ -252,10 +252,48 @@ class SwinEmbeddings(nn.Module):
|
|||||||
self.norm = nn.LayerNorm(config.embed_dim)
|
self.norm = nn.LayerNorm(config.embed_dim)
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
|
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
This method interpolates the pre-trained position encodings so that the model can be used on higher
|
||||||
|
resolution images.
|
||||||
|
|
||||||
|
Source:
|
||||||
|
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
||||||
|
"""
|
||||||
|
|
||||||
|
num_patches = embeddings.shape[1] - 1
|
||||||
|
num_positions = self.position_embeddings.shape[1] - 1
|
||||||
|
if num_patches == num_positions and height == width:
|
||||||
|
return self.position_embeddings
|
||||||
|
class_pos_embed = self.position_embeddings[:, 0]
|
||||||
|
patch_pos_embed = self.position_embeddings[:, 1:]
|
||||||
|
dim = embeddings.shape[-1]
|
||||||
|
h0 = height // self.config.patch_size
|
||||||
|
w0 = width // self.config.patch_size
|
||||||
|
# we add a small number to avoid floating point error in the interpolation
|
||||||
|
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
||||||
|
h0, w0 = h0 + 0.1, w0 + 0.1
|
||||||
|
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
|
||||||
|
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
||||||
|
patch_pos_embed = nn.functional.interpolate(
|
||||||
|
patch_pos_embed,
|
||||||
|
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
|
||||||
|
mode="bicubic",
|
||||||
|
align_corners=False,
|
||||||
|
)
|
||||||
|
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
||||||
|
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, pixel_values: Optional[torch.FloatTensor], bool_masked_pos: Optional[torch.BoolTensor] = None
|
self,
|
||||||
|
pixel_values: Optional[torch.FloatTensor],
|
||||||
|
bool_masked_pos: Optional[torch.BoolTensor] = None,
|
||||||
|
interpolate_pos_encoding: bool = False,
|
||||||
) -> Tuple[torch.Tensor]:
|
) -> Tuple[torch.Tensor]:
|
||||||
embeddings, output_dimensions = self.patch_embeddings(pixel_values)
|
_, num_channels, height, width = pixel_values.shape
|
||||||
|
embeddings, output_dimensions = self.patch_embeddings(
|
||||||
|
pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
|
||||||
|
)
|
||||||
embeddings = self.norm(embeddings)
|
embeddings = self.norm(embeddings)
|
||||||
batch_size, seq_len, _ = embeddings.size()
|
batch_size, seq_len, _ = embeddings.size()
|
||||||
|
|
||||||
@@ -266,6 +304,9 @@ class SwinEmbeddings(nn.Module):
|
|||||||
embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
|
embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
|
||||||
|
|
||||||
if self.position_embeddings is not None:
|
if self.position_embeddings is not None:
|
||||||
|
if interpolate_pos_encoding:
|
||||||
|
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
|
||||||
|
else:
|
||||||
embeddings = embeddings + self.position_embeddings
|
embeddings = embeddings + self.position_embeddings
|
||||||
|
|
||||||
embeddings = self.dropout(embeddings)
|
embeddings = self.dropout(embeddings)
|
||||||
@@ -304,7 +345,9 @@ class SwinPatchEmbeddings(nn.Module):
|
|||||||
pixel_values = nn.functional.pad(pixel_values, pad_values)
|
pixel_values = nn.functional.pad(pixel_values, pad_values)
|
||||||
return pixel_values
|
return pixel_values
|
||||||
|
|
||||||
def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
|
def forward(
|
||||||
|
self, pixel_values: Optional[torch.FloatTensor], interpolate_pos_encoding: bool = False
|
||||||
|
) -> Tuple[torch.Tensor, Tuple[int]]:
|
||||||
_, num_channels, height, width = pixel_values.shape
|
_, num_channels, height, width = pixel_values.shape
|
||||||
if num_channels != self.num_channels:
|
if num_channels != self.num_channels:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -312,6 +355,11 @@ class SwinPatchEmbeddings(nn.Module):
|
|||||||
)
|
)
|
||||||
# pad the input to be divisible by self.patch_size, if needed
|
# pad the input to be divisible by self.patch_size, if needed
|
||||||
pixel_values = self.maybe_pad(pixel_values, height, width)
|
pixel_values = self.maybe_pad(pixel_values, height, width)
|
||||||
|
if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
|
||||||
|
raise ValueError(
|
||||||
|
f"Input image size ({height}*{width}) doesn't match model"
|
||||||
|
f" ({self.image_size[0]}*{self.image_size[1]})."
|
||||||
|
)
|
||||||
embeddings = self.projection(pixel_values)
|
embeddings = self.projection(pixel_values)
|
||||||
_, _, height, width = embeddings.shape
|
_, _, height, width = embeddings.shape
|
||||||
output_dimensions = (height, width)
|
output_dimensions = (height, width)
|
||||||
@@ -924,6 +972,8 @@ SWIN_INPUTS_DOCSTRING = r"""
|
|||||||
output_hidden_states (`bool`, *optional*):
|
output_hidden_states (`bool`, *optional*):
|
||||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||||
more detail.
|
more detail.
|
||||||
|
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to interpolate the pre-trained position encodings.
|
||||||
return_dict (`bool`, *optional*):
|
return_dict (`bool`, *optional*):
|
||||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||||
"""
|
"""
|
||||||
@@ -981,6 +1031,7 @@ class SwinModel(SwinPreTrainedModel):
|
|||||||
head_mask: Optional[torch.FloatTensor] = None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
interpolate_pos_encoding: bool = False,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
) -> Union[Tuple, SwinModelOutput]:
|
) -> Union[Tuple, SwinModelOutput]:
|
||||||
r"""
|
r"""
|
||||||
@@ -1003,7 +1054,9 @@ class SwinModel(SwinPreTrainedModel):
|
|||||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||||
head_mask = self.get_head_mask(head_mask, len(self.config.depths))
|
head_mask = self.get_head_mask(head_mask, len(self.config.depths))
|
||||||
|
|
||||||
embedding_output, input_dimensions = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
|
embedding_output, input_dimensions = self.embeddings(
|
||||||
|
pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
|
||||||
|
)
|
||||||
|
|
||||||
encoder_outputs = self.encoder(
|
encoder_outputs = self.encoder(
|
||||||
embedding_output,
|
embedding_output,
|
||||||
@@ -1074,6 +1127,7 @@ class SwinForMaskedImageModeling(SwinPreTrainedModel):
|
|||||||
head_mask: Optional[torch.FloatTensor] = None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
interpolate_pos_encoding: bool = False,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
) -> Union[Tuple, SwinMaskedImageModelingOutput]:
|
) -> Union[Tuple, SwinMaskedImageModelingOutput]:
|
||||||
r"""
|
r"""
|
||||||
@@ -1113,6 +1167,7 @@ class SwinForMaskedImageModeling(SwinPreTrainedModel):
|
|||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
|
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1156,6 +1211,14 @@ class SwinForMaskedImageModeling(SwinPreTrainedModel):
|
|||||||
"""
|
"""
|
||||||
Swin Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
|
Swin Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
|
||||||
the [CLS] token) e.g. for ImageNet.
|
the [CLS] token) e.g. for ImageNet.
|
||||||
|
|
||||||
|
<Tip>
|
||||||
|
|
||||||
|
Note that it's possible to fine-tune Swin on higher resolution images than the ones it has been trained on, by
|
||||||
|
setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
|
||||||
|
position embeddings to the higher resolution.
|
||||||
|
|
||||||
|
</Tip>
|
||||||
""",
|
""",
|
||||||
SWIN_START_DOCSTRING,
|
SWIN_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
@@ -1188,6 +1251,7 @@ class SwinForImageClassification(SwinPreTrainedModel):
|
|||||||
labels: Optional[torch.LongTensor] = None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
interpolate_pos_encoding: bool = False,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
) -> Union[Tuple, SwinImageClassifierOutput]:
|
) -> Union[Tuple, SwinImageClassifierOutput]:
|
||||||
r"""
|
r"""
|
||||||
@@ -1203,6 +1267,7 @@ class SwinForImageClassification(SwinPreTrainedModel):
|
|||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
|
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -295,10 +295,48 @@ class Swinv2Embeddings(nn.Module):
|
|||||||
self.norm = nn.LayerNorm(config.embed_dim)
|
self.norm = nn.LayerNorm(config.embed_dim)
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
|
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
This method interpolates the pre-trained position encodings so that the model can be used on higher
|
||||||
|
resolution images.
|
||||||
|
|
||||||
|
Source:
|
||||||
|
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
||||||
|
"""
|
||||||
|
|
||||||
|
num_patches = embeddings.shape[1] - 1
|
||||||
|
num_positions = self.position_embeddings.shape[1] - 1
|
||||||
|
if num_patches == num_positions and height == width:
|
||||||
|
return self.position_embeddings
|
||||||
|
class_pos_embed = self.position_embeddings[:, 0]
|
||||||
|
patch_pos_embed = self.position_embeddings[:, 1:]
|
||||||
|
dim = embeddings.shape[-1]
|
||||||
|
h0 = height // self.config.patch_size
|
||||||
|
w0 = width // self.config.patch_size
|
||||||
|
# we add a small number to avoid floating point error in the interpolation
|
||||||
|
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
||||||
|
h0, w0 = h0 + 0.1, w0 + 0.1
|
||||||
|
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
|
||||||
|
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
||||||
|
patch_pos_embed = nn.functional.interpolate(
|
||||||
|
patch_pos_embed,
|
||||||
|
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
|
||||||
|
mode="bicubic",
|
||||||
|
align_corners=False,
|
||||||
|
)
|
||||||
|
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
||||||
|
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, pixel_values: Optional[torch.FloatTensor], bool_masked_pos: Optional[torch.BoolTensor] = None
|
self,
|
||||||
|
pixel_values: Optional[torch.FloatTensor],
|
||||||
|
bool_masked_pos: Optional[torch.BoolTensor] = None,
|
||||||
|
interpolate_pos_encoding: bool = False,
|
||||||
) -> Tuple[torch.Tensor]:
|
) -> Tuple[torch.Tensor]:
|
||||||
embeddings, output_dimensions = self.patch_embeddings(pixel_values)
|
_, num_channels, height, width = pixel_values.shape
|
||||||
|
embeddings, output_dimensions = self.patch_embeddings(
|
||||||
|
pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
|
||||||
|
)
|
||||||
embeddings = self.norm(embeddings)
|
embeddings = self.norm(embeddings)
|
||||||
batch_size, seq_len, _ = embeddings.size()
|
batch_size, seq_len, _ = embeddings.size()
|
||||||
|
|
||||||
@@ -309,6 +347,9 @@ class Swinv2Embeddings(nn.Module):
|
|||||||
embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
|
embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
|
||||||
|
|
||||||
if self.position_embeddings is not None:
|
if self.position_embeddings is not None:
|
||||||
|
if interpolate_pos_encoding:
|
||||||
|
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
|
||||||
|
else:
|
||||||
embeddings = embeddings + self.position_embeddings
|
embeddings = embeddings + self.position_embeddings
|
||||||
|
|
||||||
embeddings = self.dropout(embeddings)
|
embeddings = self.dropout(embeddings)
|
||||||
@@ -348,7 +389,9 @@ class Swinv2PatchEmbeddings(nn.Module):
|
|||||||
pixel_values = nn.functional.pad(pixel_values, pad_values)
|
pixel_values = nn.functional.pad(pixel_values, pad_values)
|
||||||
return pixel_values
|
return pixel_values
|
||||||
|
|
||||||
def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
|
def forward(
|
||||||
|
self, pixel_values: Optional[torch.FloatTensor], interpolate_pos_encoding: bool = False
|
||||||
|
) -> Tuple[torch.Tensor, Tuple[int]]:
|
||||||
_, num_channels, height, width = pixel_values.shape
|
_, num_channels, height, width = pixel_values.shape
|
||||||
if num_channels != self.num_channels:
|
if num_channels != self.num_channels:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -356,6 +399,11 @@ class Swinv2PatchEmbeddings(nn.Module):
|
|||||||
)
|
)
|
||||||
# pad the input to be divisible by self.patch_size, if needed
|
# pad the input to be divisible by self.patch_size, if needed
|
||||||
pixel_values = self.maybe_pad(pixel_values, height, width)
|
pixel_values = self.maybe_pad(pixel_values, height, width)
|
||||||
|
if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
|
||||||
|
raise ValueError(
|
||||||
|
f"Input image size ({height}*{width}) doesn't match model"
|
||||||
|
f" ({self.image_size[0]}*{self.image_size[1]})."
|
||||||
|
)
|
||||||
embeddings = self.projection(pixel_values)
|
embeddings = self.projection(pixel_values)
|
||||||
_, _, height, width = embeddings.shape
|
_, _, height, width = embeddings.shape
|
||||||
output_dimensions = (height, width)
|
output_dimensions = (height, width)
|
||||||
@@ -979,6 +1027,8 @@ SWINV2_INPUTS_DOCSTRING = r"""
|
|||||||
output_hidden_states (`bool`, *optional*):
|
output_hidden_states (`bool`, *optional*):
|
||||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||||
more detail.
|
more detail.
|
||||||
|
interpolate_pos_encoding (`bool`, *optional*, default `False`):
|
||||||
|
Whether to interpolate the pre-trained position encodings.
|
||||||
return_dict (`bool`, *optional*):
|
return_dict (`bool`, *optional*):
|
||||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||||
"""
|
"""
|
||||||
@@ -1031,6 +1081,7 @@ class Swinv2Model(Swinv2PreTrainedModel):
|
|||||||
head_mask: Optional[torch.FloatTensor] = None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
interpolate_pos_encoding: bool = False,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
) -> Union[Tuple, Swinv2ModelOutput]:
|
) -> Union[Tuple, Swinv2ModelOutput]:
|
||||||
r"""
|
r"""
|
||||||
@@ -1053,7 +1104,9 @@ class Swinv2Model(Swinv2PreTrainedModel):
|
|||||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||||
head_mask = self.get_head_mask(head_mask, len(self.config.depths))
|
head_mask = self.get_head_mask(head_mask, len(self.config.depths))
|
||||||
|
|
||||||
embedding_output, input_dimensions = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
|
embedding_output, input_dimensions = self.embeddings(
|
||||||
|
pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
|
||||||
|
)
|
||||||
|
|
||||||
encoder_outputs = self.encoder(
|
encoder_outputs = self.encoder(
|
||||||
embedding_output,
|
embedding_output,
|
||||||
@@ -1126,6 +1179,7 @@ class Swinv2ForMaskedImageModeling(Swinv2PreTrainedModel):
|
|||||||
head_mask: Optional[torch.FloatTensor] = None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
interpolate_pos_encoding: bool = False,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
) -> Union[Tuple, Swinv2MaskedImageModelingOutput]:
|
) -> Union[Tuple, Swinv2MaskedImageModelingOutput]:
|
||||||
r"""
|
r"""
|
||||||
@@ -1165,6 +1219,7 @@ class Swinv2ForMaskedImageModeling(Swinv2PreTrainedModel):
|
|||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
|
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1208,6 +1263,14 @@ class Swinv2ForMaskedImageModeling(Swinv2PreTrainedModel):
|
|||||||
"""
|
"""
|
||||||
Swinv2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state
|
Swinv2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state
|
||||||
of the [CLS] token) e.g. for ImageNet.
|
of the [CLS] token) e.g. for ImageNet.
|
||||||
|
|
||||||
|
<Tip>
|
||||||
|
|
||||||
|
Note that it's possible to fine-tune SwinV2 on higher resolution images than the ones it has been trained on, by
|
||||||
|
setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
|
||||||
|
position embeddings to the higher resolution.
|
||||||
|
|
||||||
|
</Tip>
|
||||||
""",
|
""",
|
||||||
SWINV2_START_DOCSTRING,
|
SWINV2_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
@@ -1241,6 +1304,7 @@ class Swinv2ForImageClassification(Swinv2PreTrainedModel):
|
|||||||
labels: Optional[torch.LongTensor] = None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
interpolate_pos_encoding: bool = False,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
) -> Union[Tuple, Swinv2ImageClassifierOutput]:
|
) -> Union[Tuple, Swinv2ImageClassifierOutput]:
|
||||||
r"""
|
r"""
|
||||||
@@ -1256,6 +1320,7 @@ class Swinv2ForImageClassification(Swinv2PreTrainedModel):
|
|||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
|
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -493,6 +493,26 @@ class SwinModelIntegrationTest(unittest.TestCase):
|
|||||||
expected_slice = torch.tensor([-0.0948, -0.6454, -0.0921]).to(torch_device)
|
expected_slice = torch.tensor([-0.0948, -0.6454, -0.0921]).to(torch_device)
|
||||||
self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
|
self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
|
||||||
|
|
||||||
|
@slow
def test_inference_interpolate_pos_encoding(self):
    """Run pretrained Swin with `interpolate_pos_encoding=True` on a resized image.

    The fixture image is processed with a requested size of 481x481 and fed to
    the `microsoft/swin-tiny-patch4-window7-224` checkpoint. Only the shape of
    `last_hidden_state` is asserted; no output values are compared.
    """
    # Swin models have an `interpolate_pos_encoding` argument in their forward method,
    # allowing to interpolate the pre-trained position embeddings in order to use
    # the model on higher resolutions.
    model = SwinModel.from_pretrained("microsoft/swin-tiny-patch4-window7-224").to(torch_device)

    image_processor = self.default_image_processor
    # Use a context manager so the fixture's file handle is released once the
    # processor has converted the image to tensors.
    with Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") as image:
        inputs = image_processor(images=image, size={"height": 481, "width": 481}, return_tensors="pt")
    pixel_values = inputs.pixel_values.to(torch_device)

    # forward pass
    with torch.no_grad():
        outputs = model(pixel_values, interpolate_pos_encoding=True)

    # verify the shape of the last hidden state (logits are not checked here)
    expected_shape = torch.Size((1, 256, 768))
    self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class SwinBackboneTest(unittest.TestCase, BackboneTesterMixin):
|
class SwinBackboneTest(unittest.TestCase, BackboneTesterMixin):
|
||||||
|
|||||||
@@ -485,6 +485,26 @@ class Swinv2ModelIntegrationTest(unittest.TestCase):
|
|||||||
expected_slice = torch.tensor([-0.3947, -0.4306, 0.0026]).to(torch_device)
|
expected_slice = torch.tensor([-0.3947, -0.4306, 0.0026]).to(torch_device)
|
||||||
self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
|
self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
|
||||||
|
|
||||||
|
@slow
def test_inference_interpolate_pos_encoding(self):
    """Run pretrained Swinv2 with `interpolate_pos_encoding=True` on a resized image.

    The fixture image is processed with a requested size of 481x481 and fed to
    the `microsoft/swinv2-tiny-patch4-window8-256` checkpoint. Only the shape of
    `last_hidden_state` is asserted; no output values are compared.
    """
    # Swinv2 models have an `interpolate_pos_encoding` argument in their forward method,
    # allowing to interpolate the pre-trained position embeddings in order to use
    # the model on higher resolutions.
    model = Swinv2Model.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256").to(torch_device)

    image_processor = self.default_image_processor
    # Use a context manager so the fixture's file handle is released once the
    # processor has converted the image to tensors.
    with Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") as image:
        inputs = image_processor(images=image, size={"height": 481, "width": 481}, return_tensors="pt")
    pixel_values = inputs.pixel_values.to(torch_device)

    # forward pass
    with torch.no_grad():
        outputs = model(pixel_values, interpolate_pos_encoding=True)

    # verify the shape of the last hidden state (logits are not checked here)
    expected_shape = torch.Size((1, 256, 768))
    self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class Swinv2BackboneTest(unittest.TestCase, BackboneTesterMixin):
|
class Swinv2BackboneTest(unittest.TestCase, BackboneTesterMixin):
|
||||||
|
|||||||
Reference in New Issue
Block a user