Enable dynamic resolution input for Beit (#31053)
* Initial attempt * Updates: PR suggestions * Interpolate the relative position bias when interpolate_pos_encoding is True * Add slow tag for the added tests * Add in DATA2VEC_VISION_INPUTS_DOCSTRING
This commit is contained in:
@@ -137,6 +137,12 @@ class BeitEmbeddings(nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.mask_token = None
|
self.mask_token = None
|
||||||
self.patch_embeddings = BeitPatchEmbeddings(config)
|
self.patch_embeddings = BeitPatchEmbeddings(config)
|
||||||
|
self.patch_size = config.patch_size
|
||||||
|
self.image_size = (
|
||||||
|
config.image_size
|
||||||
|
if isinstance(config.image_size, collections.abc.Iterable)
|
||||||
|
else (config.image_size, config.image_size)
|
||||||
|
)
|
||||||
num_patches = self.patch_embeddings.num_patches
|
num_patches = self.patch_embeddings.num_patches
|
||||||
if config.use_absolute_position_embeddings:
|
if config.use_absolute_position_embeddings:
|
||||||
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
|
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
|
||||||
@@ -144,7 +150,55 @@ class BeitEmbeddings(nn.Module):
|
|||||||
self.position_embeddings = None
|
self.position_embeddings = None
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.BoolTensor] = None) -> torch.Tensor:
|
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
This method allows the model to interpolate the pre-trained position encodings so that it can be used on
|
||||||
|
higher resolution images.
|
||||||
|
|
||||||
|
Source:
|
||||||
|
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
||||||
|
"""
|
||||||
|
num_patches = embeddings.shape[1] - 1
|
||||||
|
num_positions = self.position_embeddings.shape[1] - 1
|
||||||
|
if num_patches == num_positions and height == width:
|
||||||
|
return self.position_embeddings
|
||||||
|
|
||||||
|
class_pos_embed = self.position_embeddings[:, 0]
|
||||||
|
patch_pos_embed = self.position_embeddings[:, 1:]
|
||||||
|
dim = embeddings.shape[-1]
|
||||||
|
h = height // self.patch_size
|
||||||
|
w = width // self.patch_size
|
||||||
|
# we add a small number to avoid floating point error in the interpolation
|
||||||
|
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
||||||
|
h, w = h + 0.1, w + 0.1
|
||||||
|
|
||||||
|
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
|
||||||
|
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
||||||
|
patch_pos_embed = nn.functional.interpolate(
|
||||||
|
patch_pos_embed,
|
||||||
|
scale_factor=(h / math.sqrt(num_positions), w / math.sqrt(num_positions)),
|
||||||
|
mode="bicubic",
|
||||||
|
align_corners=False,
|
||||||
|
)
|
||||||
|
if int(h) != patch_pos_embed.shape[-2] or int(w) != patch_pos_embed.shape[-1]:
|
||||||
|
raise ValueError("Width or height does not match with the interpolated position embeddings")
|
||||||
|
|
||||||
|
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
||||||
|
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
pixel_values: torch.Tensor,
|
||||||
|
bool_masked_pos: Optional[torch.BoolTensor] = None,
|
||||||
|
interpolate_pos_encoding: bool = False,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
_, _, height, width = pixel_values.shape
|
||||||
|
if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
|
||||||
|
raise ValueError(
|
||||||
|
f"Input image size ({height}*{width}) doesn't match model"
|
||||||
|
f" ({self.image_size[0]}*{self.image_size[1]})."
|
||||||
|
)
|
||||||
|
|
||||||
embeddings, (patch_height, patch_width) = self.patch_embeddings(
|
embeddings, (patch_height, patch_width) = self.patch_embeddings(
|
||||||
pixel_values, self.position_embeddings[:, 1:, :] if self.position_embeddings is not None else None
|
pixel_values, self.position_embeddings[:, 1:, :] if self.position_embeddings is not None else None
|
||||||
)
|
)
|
||||||
@@ -158,6 +212,9 @@ class BeitEmbeddings(nn.Module):
|
|||||||
|
|
||||||
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
|
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
|
||||||
if self.position_embeddings is not None:
|
if self.position_embeddings is not None:
|
||||||
|
if interpolate_pos_encoding:
|
||||||
|
cls_tokens = cls_tokens + self.interpolate_pos_encoding(embeddings, height, width)
|
||||||
|
else:
|
||||||
cls_tokens = cls_tokens + self.position_embeddings[:, :1, :]
|
cls_tokens = cls_tokens + self.position_embeddings[:, :1, :]
|
||||||
|
|
||||||
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
|
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
|
||||||
@@ -191,7 +248,11 @@ class BeitPatchEmbeddings(nn.Module):
|
|||||||
|
|
||||||
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
|
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
|
||||||
|
|
||||||
def forward(self, pixel_values: torch.Tensor, position_embedding: Optional[torch.Tensor] = None) -> torch.Tensor:
|
def forward(
|
||||||
|
self,
|
||||||
|
pixel_values: torch.Tensor,
|
||||||
|
position_embedding: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
batch_size, num_channels, height, width = pixel_values.shape
|
batch_size, num_channels, height, width = pixel_values.shape
|
||||||
if num_channels != self.num_channels:
|
if num_channels != self.num_channels:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -251,6 +312,7 @@ class BeitSelfAttention(nn.Module):
|
|||||||
head_mask: Optional[torch.Tensor] = None,
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
relative_position_bias: Optional["BeitRelativePositionBias"] = None,
|
relative_position_bias: Optional["BeitRelativePositionBias"] = None,
|
||||||
|
interpolate_pos_encoding: bool = False,
|
||||||
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
|
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
|
||||||
mixed_query_layer = self.query(hidden_states)
|
mixed_query_layer = self.query(hidden_states)
|
||||||
|
|
||||||
@@ -265,7 +327,9 @@ class BeitSelfAttention(nn.Module):
|
|||||||
|
|
||||||
# Add relative position bias if present.
|
# Add relative position bias if present.
|
||||||
if self.relative_position_bias is not None:
|
if self.relative_position_bias is not None:
|
||||||
attention_scores = attention_scores + self.relative_position_bias().unsqueeze(0)
|
attention_scores = attention_scores + self.relative_position_bias(
|
||||||
|
interpolate_pos_encoding, attention_scores.shape[2]
|
||||||
|
).unsqueeze(0)
|
||||||
|
|
||||||
# Add shared relative position bias if provided.
|
# Add shared relative position bias if provided.
|
||||||
if relative_position_bias is not None:
|
if relative_position_bias is not None:
|
||||||
@@ -342,8 +406,11 @@ class BeitAttention(nn.Module):
|
|||||||
head_mask: Optional[torch.Tensor] = None,
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
relative_position_bias: Optional["BeitRelativePositionBias"] = None,
|
relative_position_bias: Optional["BeitRelativePositionBias"] = None,
|
||||||
|
interpolate_pos_encoding: bool = False,
|
||||||
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
|
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
|
||||||
self_outputs = self.attention(hidden_states, head_mask, output_attentions, relative_position_bias)
|
self_outputs = self.attention(
|
||||||
|
hidden_states, head_mask, output_attentions, relative_position_bias, interpolate_pos_encoding
|
||||||
|
)
|
||||||
|
|
||||||
attention_output = self.output(self_outputs[0], hidden_states)
|
attention_output = self.output(self_outputs[0], hidden_states)
|
||||||
|
|
||||||
@@ -407,12 +474,14 @@ class BeitLayer(nn.Module):
|
|||||||
head_mask: Optional[torch.Tensor] = None,
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
relative_position_bias: Optional["BeitRelativePositionBias"] = None,
|
relative_position_bias: Optional["BeitRelativePositionBias"] = None,
|
||||||
|
interpolate_pos_encoding: bool = False,
|
||||||
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
|
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
|
||||||
self_attention_outputs = self.attention(
|
self_attention_outputs = self.attention(
|
||||||
self.layernorm_before(hidden_states), # in BEiT, layernorm is applied before self-attention
|
self.layernorm_before(hidden_states), # in BEiT, layernorm is applied before self-attention
|
||||||
head_mask,
|
head_mask,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
relative_position_bias=relative_position_bias,
|
relative_position_bias=relative_position_bias,
|
||||||
|
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||||
)
|
)
|
||||||
attention_output = self_attention_outputs[0]
|
attention_output = self_attention_outputs[0]
|
||||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||||
@@ -471,12 +540,21 @@ class BeitRelativePositionBias(nn.Module):
|
|||||||
|
|
||||||
self.register_buffer("relative_position_index", relative_position_index, persistent=False)
|
self.register_buffer("relative_position_index", relative_position_index, persistent=False)
|
||||||
|
|
||||||
def forward(self) -> torch.Tensor:
|
def forward(self, interpolate_pos_encoding: bool = False, dim_size: Optional[int] = None) -> torch.Tensor:
|
||||||
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
||||||
self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1
|
self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1
|
||||||
) # Wh*Ww,Wh*Ww,nH
|
) # Wh*Ww,Wh*Ww,nH
|
||||||
|
|
||||||
return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
||||||
|
if interpolate_pos_encoding:
|
||||||
|
relative_position_bias = nn.functional.interpolate(
|
||||||
|
relative_position_bias.unsqueeze(1),
|
||||||
|
size=(dim_size, dim_size),
|
||||||
|
mode="bilinear",
|
||||||
|
align_corners=False,
|
||||||
|
).squeeze(1)
|
||||||
|
|
||||||
|
return relative_position_bias
|
||||||
|
|
||||||
|
|
||||||
class BeitEncoder(nn.Module):
|
class BeitEncoder(nn.Module):
|
||||||
@@ -508,6 +586,7 @@ class BeitEncoder(nn.Module):
|
|||||||
head_mask: Optional[torch.Tensor] = None,
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
output_hidden_states: bool = False,
|
output_hidden_states: bool = False,
|
||||||
|
interpolate_pos_encoding: bool = False,
|
||||||
return_dict: bool = True,
|
return_dict: bool = True,
|
||||||
) -> Union[tuple, BaseModelOutput]:
|
) -> Union[tuple, BaseModelOutput]:
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
@@ -528,9 +607,13 @@ class BeitEncoder(nn.Module):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
relative_position_bias = (
|
relative_position_bias = (
|
||||||
self.relative_position_bias() if self.relative_position_bias is not None else None
|
self.relative_position_bias(interpolate_pos_encoding, hidden_states.shape[1])
|
||||||
|
if self.relative_position_bias is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
layer_outputs = layer_module(
|
||||||
|
hidden_states, layer_head_mask, output_attentions, relative_position_bias, interpolate_pos_encoding
|
||||||
)
|
)
|
||||||
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions, relative_position_bias)
|
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
@@ -607,6 +690,8 @@ BEIT_INPUTS_DOCSTRING = r"""
|
|||||||
output_hidden_states (`bool`, *optional*):
|
output_hidden_states (`bool`, *optional*):
|
||||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||||
more detail.
|
more detail.
|
||||||
|
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to interpolate the pre-trained position encodings.
|
||||||
return_dict (`bool`, *optional*):
|
return_dict (`bool`, *optional*):
|
||||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||||
"""
|
"""
|
||||||
@@ -658,6 +743,7 @@ class BeitModel(BeitPreTrainedModel):
|
|||||||
head_mask: Optional[torch.Tensor] = None,
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
interpolate_pos_encoding: bool = False,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
) -> Union[tuple, BeitModelOutputWithPooling]:
|
) -> Union[tuple, BeitModelOutputWithPooling]:
|
||||||
r"""
|
r"""
|
||||||
@@ -680,7 +766,9 @@ class BeitModel(BeitPreTrainedModel):
|
|||||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||||
|
|
||||||
embedding_output, (patch_height, patch_width) = self.embeddings(pixel_values, bool_masked_pos)
|
embedding_output, (patch_height, patch_width) = self.embeddings(
|
||||||
|
pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
|
||||||
|
)
|
||||||
|
|
||||||
encoder_outputs = self.encoder(
|
encoder_outputs = self.encoder(
|
||||||
embedding_output,
|
embedding_output,
|
||||||
@@ -688,6 +776,7 @@ class BeitModel(BeitPreTrainedModel):
|
|||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
|
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||||
)
|
)
|
||||||
sequence_output = encoder_outputs[0]
|
sequence_output = encoder_outputs[0]
|
||||||
sequence_output = self.layernorm(sequence_output)
|
sequence_output = self.layernorm(sequence_output)
|
||||||
@@ -755,6 +844,7 @@ class BeitForMaskedImageModeling(BeitPreTrainedModel):
|
|||||||
labels: Optional[torch.Tensor] = None,
|
labels: Optional[torch.Tensor] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
interpolate_pos_encoding: bool = False,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
) -> Union[tuple, MaskedLMOutput]:
|
) -> Union[tuple, MaskedLMOutput]:
|
||||||
r"""
|
r"""
|
||||||
@@ -800,6 +890,7 @@ class BeitForMaskedImageModeling(BeitPreTrainedModel):
|
|||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
|
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -858,6 +949,7 @@ class BeitForImageClassification(BeitPreTrainedModel):
|
|||||||
labels: Optional[torch.Tensor] = None,
|
labels: Optional[torch.Tensor] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
interpolate_pos_encoding: bool = False,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
) -> Union[tuple, ImageClassifierOutput]:
|
) -> Union[tuple, ImageClassifierOutput]:
|
||||||
r"""
|
r"""
|
||||||
@@ -872,6 +964,7 @@ class BeitForImageClassification(BeitPreTrainedModel):
|
|||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
|
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1215,6 +1308,7 @@ class BeitForSemanticSegmentation(BeitPreTrainedModel):
|
|||||||
labels: Optional[torch.Tensor] = None,
|
labels: Optional[torch.Tensor] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
interpolate_pos_encoding: bool = False,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
) -> Union[tuple, SemanticSegmenterOutput]:
|
) -> Union[tuple, SemanticSegmenterOutput]:
|
||||||
r"""
|
r"""
|
||||||
@@ -1255,6 +1349,7 @@ class BeitForSemanticSegmentation(BeitPreTrainedModel):
|
|||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=True, # we need the intermediate hidden states
|
output_hidden_states=True, # we need the intermediate hidden states
|
||||||
|
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -136,6 +136,12 @@ class Data2VecVisionEmbeddings(nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.mask_token = None
|
self.mask_token = None
|
||||||
self.patch_embeddings = Data2VecVisionPatchEmbeddings(config)
|
self.patch_embeddings = Data2VecVisionPatchEmbeddings(config)
|
||||||
|
self.patch_size = config.patch_size
|
||||||
|
self.image_size = (
|
||||||
|
config.image_size
|
||||||
|
if isinstance(config.image_size, collections.abc.Iterable)
|
||||||
|
else (config.image_size, config.image_size)
|
||||||
|
)
|
||||||
num_patches = self.patch_embeddings.num_patches
|
num_patches = self.patch_embeddings.num_patches
|
||||||
if config.use_absolute_position_embeddings:
|
if config.use_absolute_position_embeddings:
|
||||||
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
|
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
|
||||||
@@ -143,7 +149,55 @@ class Data2VecVisionEmbeddings(nn.Module):
|
|||||||
self.position_embeddings = None
|
self.position_embeddings = None
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.BoolTensor] = None) -> torch.Tensor:
|
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
This method allows the model to interpolate the pre-trained position encodings so that it can be used on
|
||||||
|
higher resolution images.
|
||||||
|
|
||||||
|
Source:
|
||||||
|
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
||||||
|
"""
|
||||||
|
num_patches = embeddings.shape[1] - 1
|
||||||
|
num_positions = self.position_embeddings.shape[1] - 1
|
||||||
|
if num_patches == num_positions and height == width:
|
||||||
|
return self.position_embeddings
|
||||||
|
|
||||||
|
class_pos_embed = self.position_embeddings[:, 0]
|
||||||
|
patch_pos_embed = self.position_embeddings[:, 1:]
|
||||||
|
dim = embeddings.shape[-1]
|
||||||
|
h = height // self.patch_size
|
||||||
|
w = width // self.patch_size
|
||||||
|
# we add a small number to avoid floating point error in the interpolation
|
||||||
|
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
||||||
|
h, w = h + 0.1, w + 0.1
|
||||||
|
|
||||||
|
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
|
||||||
|
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
||||||
|
patch_pos_embed = nn.functional.interpolate(
|
||||||
|
patch_pos_embed,
|
||||||
|
scale_factor=(h / math.sqrt(num_positions), w / math.sqrt(num_positions)),
|
||||||
|
mode="bicubic",
|
||||||
|
align_corners=False,
|
||||||
|
)
|
||||||
|
if int(h) != patch_pos_embed.shape[-2] or int(w) != patch_pos_embed.shape[-1]:
|
||||||
|
raise ValueError("Width or height does not match with the interpolated position embeddings")
|
||||||
|
|
||||||
|
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
||||||
|
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
pixel_values: torch.Tensor,
|
||||||
|
bool_masked_pos: Optional[torch.BoolTensor] = None,
|
||||||
|
interpolate_pos_encoding: bool = False,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
_, _, height, width = pixel_values.shape
|
||||||
|
if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
|
||||||
|
raise ValueError(
|
||||||
|
f"Input image size ({height}*{width}) doesn't match model"
|
||||||
|
f" ({self.image_size[0]}*{self.image_size[1]})."
|
||||||
|
)
|
||||||
|
|
||||||
embeddings, (patch_height, patch_width) = self.patch_embeddings(
|
embeddings, (patch_height, patch_width) = self.patch_embeddings(
|
||||||
pixel_values, self.position_embeddings[:, 1:, :] if self.position_embeddings is not None else None
|
pixel_values, self.position_embeddings[:, 1:, :] if self.position_embeddings is not None else None
|
||||||
)
|
)
|
||||||
@@ -157,6 +211,9 @@ class Data2VecVisionEmbeddings(nn.Module):
|
|||||||
|
|
||||||
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
|
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
|
||||||
if self.position_embeddings is not None:
|
if self.position_embeddings is not None:
|
||||||
|
if interpolate_pos_encoding:
|
||||||
|
cls_tokens = cls_tokens + self.interpolate_pos_encoding(embeddings, height, width)
|
||||||
|
else:
|
||||||
cls_tokens = cls_tokens + self.position_embeddings[:, :1, :]
|
cls_tokens = cls_tokens + self.position_embeddings[:, :1, :]
|
||||||
|
|
||||||
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
|
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
|
||||||
@@ -191,7 +248,11 @@ class Data2VecVisionPatchEmbeddings(nn.Module):
|
|||||||
|
|
||||||
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
|
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
|
||||||
|
|
||||||
def forward(self, pixel_values: torch.Tensor, position_embedding: Optional[torch.Tensor] = None) -> torch.Tensor:
|
def forward(
|
||||||
|
self,
|
||||||
|
pixel_values: torch.Tensor,
|
||||||
|
position_embedding: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
batch_size, num_channels, height, width = pixel_values.shape
|
batch_size, num_channels, height, width = pixel_values.shape
|
||||||
if num_channels != self.num_channels:
|
if num_channels != self.num_channels:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -252,6 +313,7 @@ class Data2VecVisionSelfAttention(nn.Module):
|
|||||||
head_mask: Optional[torch.Tensor] = None,
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
relative_position_bias: Optional["Data2VecVisionRelativePositionBias"] = None,
|
relative_position_bias: Optional["Data2VecVisionRelativePositionBias"] = None,
|
||||||
|
interpolate_pos_encoding: bool = False,
|
||||||
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
|
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
|
||||||
mixed_query_layer = self.query(hidden_states)
|
mixed_query_layer = self.query(hidden_states)
|
||||||
|
|
||||||
@@ -266,7 +328,9 @@ class Data2VecVisionSelfAttention(nn.Module):
|
|||||||
|
|
||||||
# Add relative position bias if present.
|
# Add relative position bias if present.
|
||||||
if self.relative_position_bias is not None:
|
if self.relative_position_bias is not None:
|
||||||
attention_scores = attention_scores + self.relative_position_bias().unsqueeze(0)
|
attention_scores = attention_scores + self.relative_position_bias(
|
||||||
|
interpolate_pos_encoding, attention_scores.shape[2]
|
||||||
|
).unsqueeze(0)
|
||||||
|
|
||||||
# Add shared relative position bias if provided.
|
# Add shared relative position bias if provided.
|
||||||
if relative_position_bias is not None:
|
if relative_position_bias is not None:
|
||||||
@@ -345,8 +409,11 @@ class Data2VecVisionAttention(nn.Module):
|
|||||||
head_mask: Optional[torch.Tensor] = None,
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
relative_position_bias: Optional["Data2VecVisionRelativePositionBias"] = None,
|
relative_position_bias: Optional["Data2VecVisionRelativePositionBias"] = None,
|
||||||
|
interpolate_pos_encoding: bool = False,
|
||||||
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
|
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
|
||||||
self_outputs = self.attention(hidden_states, head_mask, output_attentions, relative_position_bias)
|
self_outputs = self.attention(
|
||||||
|
hidden_states, head_mask, output_attentions, relative_position_bias, interpolate_pos_encoding
|
||||||
|
)
|
||||||
|
|
||||||
attention_output = self.output(self_outputs[0], hidden_states)
|
attention_output = self.output(self_outputs[0], hidden_states)
|
||||||
|
|
||||||
@@ -415,12 +482,14 @@ class Data2VecVisionLayer(nn.Module):
|
|||||||
head_mask: Optional[torch.Tensor] = None,
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
relative_position_bias: Optional["Data2VecVisionRelativePositionBias"] = None,
|
relative_position_bias: Optional["Data2VecVisionRelativePositionBias"] = None,
|
||||||
|
interpolate_pos_encoding: bool = False,
|
||||||
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
|
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
|
||||||
self_attention_outputs = self.attention(
|
self_attention_outputs = self.attention(
|
||||||
self.layernorm_before(hidden_states), # in Data2VecVision, layernorm is applied before self-attention
|
self.layernorm_before(hidden_states), # in Data2VecVision, layernorm is applied before self-attention
|
||||||
head_mask,
|
head_mask,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
relative_position_bias=relative_position_bias,
|
relative_position_bias=relative_position_bias,
|
||||||
|
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||||
)
|
)
|
||||||
attention_output = self_attention_outputs[0]
|
attention_output = self_attention_outputs[0]
|
||||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||||
@@ -480,12 +549,21 @@ class Data2VecVisionRelativePositionBias(nn.Module):
|
|||||||
|
|
||||||
self.register_buffer("relative_position_index", relative_position_index, persistent=False)
|
self.register_buffer("relative_position_index", relative_position_index, persistent=False)
|
||||||
|
|
||||||
def forward(self) -> torch.Tensor:
|
def forward(self, interpolate_pos_encoding: bool = False, dim_size: Optional[int] = None) -> torch.Tensor:
|
||||||
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
||||||
self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1
|
self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1
|
||||||
) # Wh*Ww,Wh*Ww,nH
|
) # Wh*Ww,Wh*Ww,nH
|
||||||
|
|
||||||
return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
||||||
|
if interpolate_pos_encoding:
|
||||||
|
relative_position_bias = nn.functional.interpolate(
|
||||||
|
relative_position_bias.unsqueeze(1),
|
||||||
|
size=(dim_size, dim_size),
|
||||||
|
mode="bilinear",
|
||||||
|
align_corners=False,
|
||||||
|
).squeeze(1)
|
||||||
|
|
||||||
|
return relative_position_bias
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.beit.modeling_beit.BeitEncoder with Beit->Data2VecVision
|
# Copied from transformers.models.beit.modeling_beit.BeitEncoder with Beit->Data2VecVision
|
||||||
@@ -518,6 +596,7 @@ class Data2VecVisionEncoder(nn.Module):
|
|||||||
head_mask: Optional[torch.Tensor] = None,
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
output_hidden_states: bool = False,
|
output_hidden_states: bool = False,
|
||||||
|
interpolate_pos_encoding: bool = False,
|
||||||
return_dict: bool = True,
|
return_dict: bool = True,
|
||||||
) -> Union[tuple, BaseModelOutput]:
|
) -> Union[tuple, BaseModelOutput]:
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
@@ -538,9 +617,13 @@ class Data2VecVisionEncoder(nn.Module):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
relative_position_bias = (
|
relative_position_bias = (
|
||||||
self.relative_position_bias() if self.relative_position_bias is not None else None
|
self.relative_position_bias(interpolate_pos_encoding, hidden_states.shape[1])
|
||||||
|
if self.relative_position_bias is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
layer_outputs = layer_module(
|
||||||
|
hidden_states, layer_head_mask, output_attentions, relative_position_bias, interpolate_pos_encoding
|
||||||
)
|
)
|
||||||
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions, relative_position_bias)
|
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
@@ -618,6 +701,8 @@ DATA2VEC_VISION_INPUTS_DOCSTRING = r"""
|
|||||||
output_hidden_states (`bool`, *optional*):
|
output_hidden_states (`bool`, *optional*):
|
||||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||||
more detail.
|
more detail.
|
||||||
|
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to interpolate the pre-trained position encodings.
|
||||||
return_dict (`bool`, *optional*):
|
return_dict (`bool`, *optional*):
|
||||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||||
"""
|
"""
|
||||||
@@ -670,6 +755,7 @@ class Data2VecVisionModel(Data2VecVisionPreTrainedModel):
|
|||||||
head_mask: Optional[torch.Tensor] = None,
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
interpolate_pos_encoding: bool = False,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
) -> Union[tuple, Data2VecVisionModelOutputWithPooling]:
|
) -> Union[tuple, Data2VecVisionModelOutputWithPooling]:
|
||||||
r"""
|
r"""
|
||||||
@@ -692,7 +778,9 @@ class Data2VecVisionModel(Data2VecVisionPreTrainedModel):
|
|||||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||||
|
|
||||||
embedding_output, (patch_height, patch_width) = self.embeddings(pixel_values, bool_masked_pos)
|
embedding_output, (patch_height, patch_width) = self.embeddings(
|
||||||
|
pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
|
||||||
|
)
|
||||||
|
|
||||||
encoder_outputs = self.encoder(
|
encoder_outputs = self.encoder(
|
||||||
embedding_output,
|
embedding_output,
|
||||||
@@ -700,6 +788,7 @@ class Data2VecVisionModel(Data2VecVisionPreTrainedModel):
|
|||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
|
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||||
)
|
)
|
||||||
sequence_output = encoder_outputs[0]
|
sequence_output = encoder_outputs[0]
|
||||||
sequence_output = self.layernorm(sequence_output)
|
sequence_output = self.layernorm(sequence_output)
|
||||||
@@ -772,6 +861,7 @@ class Data2VecVisionForImageClassification(Data2VecVisionPreTrainedModel):
|
|||||||
labels: Optional[torch.Tensor] = None,
|
labels: Optional[torch.Tensor] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
interpolate_pos_encoding: bool = False,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
) -> Union[tuple, ImageClassifierOutput]:
|
) -> Union[tuple, ImageClassifierOutput]:
|
||||||
r"""
|
r"""
|
||||||
@@ -786,6 +876,7 @@ class Data2VecVisionForImageClassification(Data2VecVisionPreTrainedModel):
|
|||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
|
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1141,6 +1232,7 @@ class Data2VecVisionForSemanticSegmentation(Data2VecVisionPreTrainedModel):
|
|||||||
labels: Optional[torch.Tensor] = None,
|
labels: Optional[torch.Tensor] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
interpolate_pos_encoding: bool = False,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
) -> Union[tuple, SemanticSegmenterOutput]:
|
) -> Union[tuple, SemanticSegmenterOutput]:
|
||||||
r"""
|
r"""
|
||||||
@@ -1181,6 +1273,7 @@ class Data2VecVisionForSemanticSegmentation(Data2VecVisionPreTrainedModel):
|
|||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=True, # we need the intermediate hidden states
|
output_hidden_states=True, # we need the intermediate hidden states
|
||||||
|
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -545,6 +545,31 @@ class BeitModelIntegrationTest(unittest.TestCase):
|
|||||||
expected_shape = torch.Size((160, 160))
|
expected_shape = torch.Size((160, 160))
|
||||||
self.assertEqual(segmentation[0].shape, expected_shape)
|
self.assertEqual(segmentation[0].shape, expected_shape)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_inference_interpolate_pos_encoding(self):
|
||||||
|
model_name = "microsoft/beit-base-patch16-224-pt22k"
|
||||||
|
model = BeitModel.from_pretrained(model_name, **{"use_absolute_position_embeddings": True}).to(torch_device)
|
||||||
|
|
||||||
|
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
||||||
|
processor = BeitImageProcessor.from_pretrained(model_name)
|
||||||
|
inputs = processor(images=image, return_tensors="pt", size={"height": 480, "width": 480})
|
||||||
|
pixel_values = inputs.pixel_values.to(torch_device)
|
||||||
|
|
||||||
|
# with interpolate_pos_encoding being False an exception should be raised with higher resolution
|
||||||
|
# images than what the model supports.
|
||||||
|
self.assertFalse(processor.do_center_crop)
|
||||||
|
with torch.no_grad():
|
||||||
|
with self.assertRaises(ValueError, msg="doesn't match model"):
|
||||||
|
model(pixel_values, interpolate_pos_encoding=False)
|
||||||
|
|
||||||
|
# with interpolate_pos_encoding being True the model should process the higher resolution image
|
||||||
|
# successfully and produce the expected output.
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(pixel_values, interpolate_pos_encoding=True)
|
||||||
|
|
||||||
|
expected_shape = torch.Size((1, 1801, 768))
|
||||||
|
self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class BeitBackboneTest(unittest.TestCase, BackboneTesterMixin):
|
class BeitBackboneTest(unittest.TestCase, BackboneTesterMixin):
|
||||||
|
|||||||
@@ -341,3 +341,30 @@ class Data2VecVisionModelIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
expected_top2 = [model.config.label2id[i] for i in ["remote control, remote", "tabby, tabby cat"]]
|
expected_top2 = [model.config.label2id[i] for i in ["remote control, remote", "tabby, tabby cat"]]
|
||||||
self.assertEqual(logits[0].topk(2).indices.cpu().tolist(), expected_top2)
|
self.assertEqual(logits[0].topk(2).indices.cpu().tolist(), expected_top2)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_inference_interpolate_pos_encoding(self):
|
||||||
|
model_name = "facebook/data2vec-vision-base-ft1k"
|
||||||
|
model = Data2VecVisionModel.from_pretrained(model_name, **{"use_absolute_position_embeddings": True}).to(
|
||||||
|
torch_device
|
||||||
|
)
|
||||||
|
|
||||||
|
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
||||||
|
processor = BeitImageProcessor.from_pretrained("facebook/data2vec-vision-base-ft1k")
|
||||||
|
inputs = processor(images=image, return_tensors="pt", size={"height": 480, "width": 480})
|
||||||
|
pixel_values = inputs.pixel_values.to(torch_device)
|
||||||
|
|
||||||
|
# with interpolate_pos_encoding being False an exception should be raised with higher resolution
|
||||||
|
# images than what the model supports.
|
||||||
|
self.assertFalse(processor.do_center_crop)
|
||||||
|
with torch.no_grad():
|
||||||
|
with self.assertRaises(ValueError, msg="doesn't match model"):
|
||||||
|
model(pixel_values, interpolate_pos_encoding=False)
|
||||||
|
|
||||||
|
# with interpolate_pos_encoding being True the model should process the higher resolution image
|
||||||
|
# successfully and produce the expected output.
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(pixel_values, interpolate_pos_encoding=True)
|
||||||
|
|
||||||
|
expected_shape = torch.Size((1, 1801, 768))
|
||||||
|
self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
|
||||||
|
|||||||
Reference in New Issue
Block a user