Enable dynamic resolution input for Swin Transformer and variants (#30656)

* add interpolation of positional encoding support to swin

* add style changes

* use default image processor and make size a dictionary

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* remove logits testing

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Refactor image size validation logic when interpolation is disabled

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* remove asserts in modeling

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* add dynamic resolution input support to swinv2

* change size to ensure interpolation encoding path is triggered

* set interpolate_pos_encoding default value to False

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* set interpolate_pos_encoding default value to False

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* set interpolate_pos_encoding default value to False

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* set interpolate_pos_encoding default value to False

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* set interpolate_pos_encoding default value to False

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* set interpolate_pos_encoding default value to False

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* set interpolate_pos_encoding default value to False

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* set interpolate_pos_encoding default value to False

* add dynamic resolution input to donut swin

* add dynamic resolution input to maskformer swin

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Abhiroop Tejomay
2024-05-17 13:38:46 -04:00
committed by GitHub
parent b6eb708bf1
commit 481a957814
6 changed files with 291 additions and 20 deletions

View File

@@ -166,10 +166,48 @@ class DonutSwinEmbeddings(nn.Module):
self.norm = nn.LayerNorm(config.embed_dim) self.norm = nn.LayerNorm(config.embed_dim)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
resolution images.
Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
"""
num_patches = embeddings.shape[1] - 1
num_positions = self.position_embeddings.shape[1] - 1
if num_patches == num_positions and height == width:
return self.position_embeddings
class_pos_embed = self.position_embeddings[:, 0]
patch_pos_embed = self.position_embeddings[:, 1:]
dim = embeddings.shape[-1]
h0 = height // self.config.patch_size
w0 = width // self.config.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
h0, w0 = h0 + 0.1, w0 + 0.1
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
mode="bicubic",
align_corners=False,
)
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
def forward( def forward(
self, pixel_values: Optional[torch.FloatTensor], bool_masked_pos: Optional[torch.BoolTensor] = None self,
pixel_values: Optional[torch.FloatTensor],
bool_masked_pos: Optional[torch.BoolTensor] = None,
interpolate_pos_encoding: bool = False,
) -> Tuple[torch.Tensor]: ) -> Tuple[torch.Tensor]:
embeddings, output_dimensions = self.patch_embeddings(pixel_values) _, num_channels, height, width = pixel_values.shape
embeddings, output_dimensions = self.patch_embeddings(
pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
)
embeddings = self.norm(embeddings) embeddings = self.norm(embeddings)
batch_size, seq_len, _ = embeddings.size() batch_size, seq_len, _ = embeddings.size()
@@ -180,6 +218,9 @@ class DonutSwinEmbeddings(nn.Module):
embeddings = embeddings * (1.0 - mask) + mask_tokens * mask embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
if self.position_embeddings is not None: if self.position_embeddings is not None:
if interpolate_pos_encoding:
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
else:
embeddings = embeddings + self.position_embeddings embeddings = embeddings + self.position_embeddings
embeddings = self.dropout(embeddings) embeddings = self.dropout(embeddings)
@@ -219,7 +260,9 @@ class DonutSwinPatchEmbeddings(nn.Module):
pixel_values = nn.functional.pad(pixel_values, pad_values) pixel_values = nn.functional.pad(pixel_values, pad_values)
return pixel_values return pixel_values
def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]: def forward(
self, pixel_values: Optional[torch.FloatTensor], interpolate_pos_encoding: bool = False
) -> Tuple[torch.Tensor, Tuple[int]]:
_, num_channels, height, width = pixel_values.shape _, num_channels, height, width = pixel_values.shape
if num_channels != self.num_channels: if num_channels != self.num_channels:
raise ValueError( raise ValueError(
@@ -227,6 +270,11 @@ class DonutSwinPatchEmbeddings(nn.Module):
) )
# pad the input to be divisible by self.patch_size, if needed # pad the input to be divisible by self.patch_size, if needed
pixel_values = self.maybe_pad(pixel_values, height, width) pixel_values = self.maybe_pad(pixel_values, height, width)
if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model"
f" ({self.image_size[0]}*{self.image_size[1]})."
)
embeddings = self.projection(pixel_values) embeddings = self.projection(pixel_values)
_, _, height, width = embeddings.shape _, _, height, width = embeddings.shape
output_dimensions = (height, width) output_dimensions = (height, width)
@@ -849,6 +897,8 @@ SWIN_INPUTS_DOCSTRING = r"""
output_hidden_states (`bool`, *optional*): output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail. more detail.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the pre-trained position encodings.
return_dict (`bool`, *optional*): return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
""" """
@@ -899,6 +949,7 @@ class DonutSwinModel(DonutSwinPreTrainedModel):
head_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
) -> Union[Tuple, DonutSwinModelOutput]: ) -> Union[Tuple, DonutSwinModelOutput]:
r""" r"""
@@ -921,7 +972,9 @@ class DonutSwinModel(DonutSwinPreTrainedModel):
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, len(self.config.depths)) head_mask = self.get_head_mask(head_mask, len(self.config.depths))
embedding_output, input_dimensions = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos) embedding_output, input_dimensions = self.embeddings(
pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
)
encoder_outputs = self.encoder( encoder_outputs = self.encoder(
embedding_output, embedding_output,

View File

@@ -163,11 +163,49 @@ class MaskFormerSwinEmbeddings(nn.Module):
self.norm = nn.LayerNorm(config.embed_dim) self.norm = nn.LayerNorm(config.embed_dim)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, pixel_values): def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
embeddings, output_dimensions = self.patch_embeddings(pixel_values) """
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
resolution images.
Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
"""
num_patches = embeddings.shape[1] - 1
num_positions = self.position_embeddings.shape[1] - 1
if num_patches == num_positions and height == width:
return self.position_embeddings
class_pos_embed = self.position_embeddings[:, 0]
patch_pos_embed = self.position_embeddings[:, 1:]
dim = embeddings.shape[-1]
h0 = height // self.config.patch_size
w0 = width // self.config.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
h0, w0 = h0 + 0.1, w0 + 0.1
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
mode="bicubic",
align_corners=False,
)
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
def forward(self, pixel_values, interpolate_pos_encoding):
_, num_channels, height, width = pixel_values.shape
embeddings, output_dimensions = self.patch_embeddings(
pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
)
embeddings = self.norm(embeddings) embeddings = self.norm(embeddings)
if self.position_embeddings is not None: if self.position_embeddings is not None:
if interpolate_pos_encoding:
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
else:
embeddings = embeddings + self.position_embeddings embeddings = embeddings + self.position_embeddings
embeddings = self.dropout(embeddings) embeddings = self.dropout(embeddings)
@@ -207,7 +245,9 @@ class MaskFormerSwinPatchEmbeddings(nn.Module):
pixel_values = nn.functional.pad(pixel_values, pad_values) pixel_values = nn.functional.pad(pixel_values, pad_values)
return pixel_values return pixel_values
def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]: def forward(
self, pixel_values: Optional[torch.FloatTensor], interpolate_pos_encoding: bool = False
) -> Tuple[torch.Tensor, Tuple[int]]:
_, num_channels, height, width = pixel_values.shape _, num_channels, height, width = pixel_values.shape
if num_channels != self.num_channels: if num_channels != self.num_channels:
raise ValueError( raise ValueError(
@@ -215,6 +255,11 @@ class MaskFormerSwinPatchEmbeddings(nn.Module):
) )
# pad the input to be divisible by self.patch_size, if needed # pad the input to be divisible by self.patch_size, if needed
pixel_values = self.maybe_pad(pixel_values, height, width) pixel_values = self.maybe_pad(pixel_values, height, width)
if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model"
f" ({self.image_size[0]}*{self.image_size[1]})."
)
embeddings = self.projection(pixel_values) embeddings = self.projection(pixel_values)
_, _, height, width = embeddings.shape _, _, height, width = embeddings.shape
output_dimensions = (height, width) output_dimensions = (height, width)
@@ -780,6 +825,7 @@ class MaskFormerSwinModel(MaskFormerSwinPreTrainedModel):
head_mask=None, head_mask=None,
output_attentions=None, output_attentions=None,
output_hidden_states=None, output_hidden_states=None,
interpolate_pos_encoding=False,
return_dict=None, return_dict=None,
): ):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -798,7 +844,9 @@ class MaskFormerSwinModel(MaskFormerSwinPreTrainedModel):
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, len(self.config.depths)) head_mask = self.get_head_mask(head_mask, len(self.config.depths))
embedding_output, input_dimensions = self.embeddings(pixel_values) embedding_output, input_dimensions = self.embeddings(
pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
)
encoder_outputs = self.encoder( encoder_outputs = self.encoder(
embedding_output, embedding_output,

View File

@@ -252,10 +252,48 @@ class SwinEmbeddings(nn.Module):
self.norm = nn.LayerNorm(config.embed_dim) self.norm = nn.LayerNorm(config.embed_dim)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
resolution images.
Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
"""
num_patches = embeddings.shape[1] - 1
num_positions = self.position_embeddings.shape[1] - 1
if num_patches == num_positions and height == width:
return self.position_embeddings
class_pos_embed = self.position_embeddings[:, 0]
patch_pos_embed = self.position_embeddings[:, 1:]
dim = embeddings.shape[-1]
h0 = height // self.config.patch_size
w0 = width // self.config.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
h0, w0 = h0 + 0.1, w0 + 0.1
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
mode="bicubic",
align_corners=False,
)
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
def forward( def forward(
self, pixel_values: Optional[torch.FloatTensor], bool_masked_pos: Optional[torch.BoolTensor] = None self,
pixel_values: Optional[torch.FloatTensor],
bool_masked_pos: Optional[torch.BoolTensor] = None,
interpolate_pos_encoding: bool = False,
) -> Tuple[torch.Tensor]: ) -> Tuple[torch.Tensor]:
embeddings, output_dimensions = self.patch_embeddings(pixel_values) _, num_channels, height, width = pixel_values.shape
embeddings, output_dimensions = self.patch_embeddings(
pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
)
embeddings = self.norm(embeddings) embeddings = self.norm(embeddings)
batch_size, seq_len, _ = embeddings.size() batch_size, seq_len, _ = embeddings.size()
@@ -266,6 +304,9 @@ class SwinEmbeddings(nn.Module):
embeddings = embeddings * (1.0 - mask) + mask_tokens * mask embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
if self.position_embeddings is not None: if self.position_embeddings is not None:
if interpolate_pos_encoding:
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
else:
embeddings = embeddings + self.position_embeddings embeddings = embeddings + self.position_embeddings
embeddings = self.dropout(embeddings) embeddings = self.dropout(embeddings)
@@ -304,7 +345,9 @@ class SwinPatchEmbeddings(nn.Module):
pixel_values = nn.functional.pad(pixel_values, pad_values) pixel_values = nn.functional.pad(pixel_values, pad_values)
return pixel_values return pixel_values
def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]: def forward(
self, pixel_values: Optional[torch.FloatTensor], interpolate_pos_encoding: bool = False
) -> Tuple[torch.Tensor, Tuple[int]]:
_, num_channels, height, width = pixel_values.shape _, num_channels, height, width = pixel_values.shape
if num_channels != self.num_channels: if num_channels != self.num_channels:
raise ValueError( raise ValueError(
@@ -312,6 +355,11 @@ class SwinPatchEmbeddings(nn.Module):
) )
# pad the input to be divisible by self.patch_size, if needed # pad the input to be divisible by self.patch_size, if needed
pixel_values = self.maybe_pad(pixel_values, height, width) pixel_values = self.maybe_pad(pixel_values, height, width)
if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model"
f" ({self.image_size[0]}*{self.image_size[1]})."
)
embeddings = self.projection(pixel_values) embeddings = self.projection(pixel_values)
_, _, height, width = embeddings.shape _, _, height, width = embeddings.shape
output_dimensions = (height, width) output_dimensions = (height, width)
@@ -924,6 +972,8 @@ SWIN_INPUTS_DOCSTRING = r"""
output_hidden_states (`bool`, *optional*): output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail. more detail.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the pre-trained position encodings.
return_dict (`bool`, *optional*): return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
""" """
@@ -981,6 +1031,7 @@ class SwinModel(SwinPreTrainedModel):
head_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
) -> Union[Tuple, SwinModelOutput]: ) -> Union[Tuple, SwinModelOutput]:
r""" r"""
@@ -1003,7 +1054,9 @@ class SwinModel(SwinPreTrainedModel):
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, len(self.config.depths)) head_mask = self.get_head_mask(head_mask, len(self.config.depths))
embedding_output, input_dimensions = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos) embedding_output, input_dimensions = self.embeddings(
pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
)
encoder_outputs = self.encoder( encoder_outputs = self.encoder(
embedding_output, embedding_output,
@@ -1074,6 +1127,7 @@ class SwinForMaskedImageModeling(SwinPreTrainedModel):
head_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
) -> Union[Tuple, SwinMaskedImageModelingOutput]: ) -> Union[Tuple, SwinMaskedImageModelingOutput]:
r""" r"""
@@ -1113,6 +1167,7 @@ class SwinForMaskedImageModeling(SwinPreTrainedModel):
head_mask=head_mask, head_mask=head_mask,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict, return_dict=return_dict,
) )
@@ -1156,6 +1211,14 @@ class SwinForMaskedImageModeling(SwinPreTrainedModel):
""" """
Swin Model transformer with an image classification head on top (a linear layer on top of the final hidden state of Swin Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
the [CLS] token) e.g. for ImageNet. the [CLS] token) e.g. for ImageNet.
<Tip>
Note that it's possible to fine-tune Swin on higher resolution images than the ones it has been trained on, by
setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
position embeddings to the higher resolution.
</Tip>
""", """,
SWIN_START_DOCSTRING, SWIN_START_DOCSTRING,
) )
@@ -1188,6 +1251,7 @@ class SwinForImageClassification(SwinPreTrainedModel):
labels: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
) -> Union[Tuple, SwinImageClassifierOutput]: ) -> Union[Tuple, SwinImageClassifierOutput]:
r""" r"""
@@ -1203,6 +1267,7 @@ class SwinForImageClassification(SwinPreTrainedModel):
head_mask=head_mask, head_mask=head_mask,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict, return_dict=return_dict,
) )

View File

@@ -295,10 +295,48 @@ class Swinv2Embeddings(nn.Module):
self.norm = nn.LayerNorm(config.embed_dim) self.norm = nn.LayerNorm(config.embed_dim)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
resolution images.
Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
"""
num_patches = embeddings.shape[1] - 1
num_positions = self.position_embeddings.shape[1] - 1
if num_patches == num_positions and height == width:
return self.position_embeddings
class_pos_embed = self.position_embeddings[:, 0]
patch_pos_embed = self.position_embeddings[:, 1:]
dim = embeddings.shape[-1]
h0 = height // self.config.patch_size
w0 = width // self.config.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
h0, w0 = h0 + 0.1, w0 + 0.1
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
mode="bicubic",
align_corners=False,
)
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
def forward( def forward(
self, pixel_values: Optional[torch.FloatTensor], bool_masked_pos: Optional[torch.BoolTensor] = None self,
pixel_values: Optional[torch.FloatTensor],
bool_masked_pos: Optional[torch.BoolTensor] = None,
interpolate_pos_encoding: bool = False,
) -> Tuple[torch.Tensor]: ) -> Tuple[torch.Tensor]:
embeddings, output_dimensions = self.patch_embeddings(pixel_values) _, num_channels, height, width = pixel_values.shape
embeddings, output_dimensions = self.patch_embeddings(
pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
)
embeddings = self.norm(embeddings) embeddings = self.norm(embeddings)
batch_size, seq_len, _ = embeddings.size() batch_size, seq_len, _ = embeddings.size()
@@ -309,6 +347,9 @@ class Swinv2Embeddings(nn.Module):
embeddings = embeddings * (1.0 - mask) + mask_tokens * mask embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
if self.position_embeddings is not None: if self.position_embeddings is not None:
if interpolate_pos_encoding:
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
else:
embeddings = embeddings + self.position_embeddings embeddings = embeddings + self.position_embeddings
embeddings = self.dropout(embeddings) embeddings = self.dropout(embeddings)
@@ -348,7 +389,9 @@ class Swinv2PatchEmbeddings(nn.Module):
pixel_values = nn.functional.pad(pixel_values, pad_values) pixel_values = nn.functional.pad(pixel_values, pad_values)
return pixel_values return pixel_values
def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]: def forward(
self, pixel_values: Optional[torch.FloatTensor], interpolate_pos_encoding: bool = False
) -> Tuple[torch.Tensor, Tuple[int]]:
_, num_channels, height, width = pixel_values.shape _, num_channels, height, width = pixel_values.shape
if num_channels != self.num_channels: if num_channels != self.num_channels:
raise ValueError( raise ValueError(
@@ -356,6 +399,11 @@ class Swinv2PatchEmbeddings(nn.Module):
) )
# pad the input to be divisible by self.patch_size, if needed # pad the input to be divisible by self.patch_size, if needed
pixel_values = self.maybe_pad(pixel_values, height, width) pixel_values = self.maybe_pad(pixel_values, height, width)
if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model"
f" ({self.image_size[0]}*{self.image_size[1]})."
)
embeddings = self.projection(pixel_values) embeddings = self.projection(pixel_values)
_, _, height, width = embeddings.shape _, _, height, width = embeddings.shape
output_dimensions = (height, width) output_dimensions = (height, width)
@@ -979,6 +1027,8 @@ SWINV2_INPUTS_DOCSTRING = r"""
output_hidden_states (`bool`, *optional*): output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail. more detail.
interpolate_pos_encoding (`bool`, *optional*, default `False`):
Whether to interpolate the pre-trained position encodings.
return_dict (`bool`, *optional*): return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
""" """
@@ -1031,6 +1081,7 @@ class Swinv2Model(Swinv2PreTrainedModel):
head_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
) -> Union[Tuple, Swinv2ModelOutput]: ) -> Union[Tuple, Swinv2ModelOutput]:
r""" r"""
@@ -1053,7 +1104,9 @@ class Swinv2Model(Swinv2PreTrainedModel):
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, len(self.config.depths)) head_mask = self.get_head_mask(head_mask, len(self.config.depths))
embedding_output, input_dimensions = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos) embedding_output, input_dimensions = self.embeddings(
pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
)
encoder_outputs = self.encoder( encoder_outputs = self.encoder(
embedding_output, embedding_output,
@@ -1126,6 +1179,7 @@ class Swinv2ForMaskedImageModeling(Swinv2PreTrainedModel):
head_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
) -> Union[Tuple, Swinv2MaskedImageModelingOutput]: ) -> Union[Tuple, Swinv2MaskedImageModelingOutput]:
r""" r"""
@@ -1165,6 +1219,7 @@ class Swinv2ForMaskedImageModeling(Swinv2PreTrainedModel):
head_mask=head_mask, head_mask=head_mask,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict, return_dict=return_dict,
) )
@@ -1208,6 +1263,14 @@ class Swinv2ForMaskedImageModeling(Swinv2PreTrainedModel):
""" """
Swinv2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state Swinv2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state
of the [CLS] token) e.g. for ImageNet. of the [CLS] token) e.g. for ImageNet.
<Tip>
Note that it's possible to fine-tune SwinV2 on higher resolution images than the ones it has been trained on, by
setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
position embeddings to the higher resolution.
</Tip>
""", """,
SWINV2_START_DOCSTRING, SWINV2_START_DOCSTRING,
) )
@@ -1241,6 +1304,7 @@ class Swinv2ForImageClassification(Swinv2PreTrainedModel):
labels: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
) -> Union[Tuple, Swinv2ImageClassifierOutput]: ) -> Union[Tuple, Swinv2ImageClassifierOutput]:
r""" r"""
@@ -1256,6 +1320,7 @@ class Swinv2ForImageClassification(Swinv2PreTrainedModel):
head_mask=head_mask, head_mask=head_mask,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict, return_dict=return_dict,
) )

View File

@@ -493,6 +493,26 @@ class SwinModelIntegrationTest(unittest.TestCase):
expected_slice = torch.tensor([-0.0948, -0.6454, -0.0921]).to(torch_device) expected_slice = torch.tensor([-0.0948, -0.6454, -0.0921]).to(torch_device)
self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4)) self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
@slow
def test_inference_interpolate_pos_encoding(self):
# Swin models have an `interpolate_pos_encoding` argument in their forward method,
# allowing to interpolate the pre-trained position embeddings in order to use
# the model on higher resolutions.
model = SwinModel.from_pretrained("microsoft/swin-tiny-patch4-window7-224").to(torch_device)
image_processor = self.default_image_processor
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
inputs = image_processor(images=image, size={"height": 481, "width": 481}, return_tensors="pt")
pixel_values = inputs.pixel_values.to(torch_device)
# forward pass
with torch.no_grad():
outputs = model(pixel_values, interpolate_pos_encoding=True)
# verify the logits
expected_shape = torch.Size((1, 256, 768))
self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
@require_torch @require_torch
class SwinBackboneTest(unittest.TestCase, BackboneTesterMixin): class SwinBackboneTest(unittest.TestCase, BackboneTesterMixin):

View File

@@ -485,6 +485,26 @@ class Swinv2ModelIntegrationTest(unittest.TestCase):
expected_slice = torch.tensor([-0.3947, -0.4306, 0.0026]).to(torch_device) expected_slice = torch.tensor([-0.3947, -0.4306, 0.0026]).to(torch_device)
self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4)) self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
@slow
def test_inference_interpolate_pos_encoding(self):
# Swinv2 models have an `interpolate_pos_encoding` argument in their forward method,
# allowing to interpolate the pre-trained position embeddings in order to use
# the model on higher resolutions.
model = Swinv2Model.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256").to(torch_device)
image_processor = self.default_image_processor
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
inputs = image_processor(images=image, size={"height": 481, "width": 481}, return_tensors="pt")
pixel_values = inputs.pixel_values.to(torch_device)
# forward pass
with torch.no_grad():
outputs = model(pixel_values, interpolate_pos_encoding=True)
# verify the logits
expected_shape = torch.Size((1, 256, 768))
self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
@require_torch @require_torch
class Swinv2BackboneTest(unittest.TestCase, BackboneTesterMixin): class Swinv2BackboneTest(unittest.TestCase, BackboneTesterMixin):