From 681183974acb987d4bb7037de3e82078bd0308a1 Mon Sep 17 00:00:00 2001 From: Omar Salman Date: Thu, 6 Jun 2024 18:47:41 +0500 Subject: [PATCH] Enable dynamic resolution input for Beit (#31053) * Initial attempt * Updates: PR suggestions * Interpolate the relative position bias when interpolate_pos_encoding is True * Add slow tag for the added tests * Add in DATA2VEC_VISION_INPUTS_DOCSTRING --- src/transformers/models/beit/modeling_beit.py | 115 ++++++++++++++++-- .../data2vec/modeling_data2vec_vision.py | 113 +++++++++++++++-- tests/models/beit/test_modeling_beit.py | 25 ++++ .../data2vec/test_modeling_data2vec_vision.py | 27 ++++ 4 files changed, 260 insertions(+), 20 deletions(-) diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index 64efc6de63..184ab55822 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -137,6 +137,12 @@ class BeitEmbeddings(nn.Module): else: self.mask_token = None self.patch_embeddings = BeitPatchEmbeddings(config) + self.patch_size = config.patch_size + self.image_size = ( + config.image_size + if isinstance(config.image_size, collections.abc.Iterable) + else (config.image_size, config.image_size) + ) num_patches = self.patch_embeddings.num_patches if config.use_absolute_position_embeddings: self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size)) @@ -144,7 +150,55 @@ class BeitEmbeddings(nn.Module): self.position_embeddings = None self.dropout = nn.Dropout(config.hidden_dropout_prob) - def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.BoolTensor] = None) -> torch.Tensor: + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows the model to interpolate the pre-trained position encodings so that it can be used on + higher resolution images. + + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + if num_patches == num_positions and height == width: + return self.position_embeddings + + class_pos_embed = self.position_embeddings[:, 0] + patch_pos_embed = self.position_embeddings[:, 1:] + dim = embeddings.shape[-1] + h = height // self.patch_size + w = width // self.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + h, w = h + 0.1, w + 0.1 + + patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + scale_factor=(h / math.sqrt(num_positions), w / math.sqrt(num_positions)), + mode="bicubic", + align_corners=False, + ) + if int(h) != patch_pos_embed.shape[-2] or int(w) != patch_pos_embed.shape[-1]: + raise ValueError("Width or height does not match with the interpolated position embeddings") + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + def forward( + self, + pixel_values: torch.Tensor, + bool_masked_pos: Optional[torch.BoolTensor] = None, + interpolate_pos_encoding: bool = False, + ) -> torch.Tensor: + _, _, height, width = pixel_values.shape + if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]): + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model" + f" ({self.image_size[0]}*{self.image_size[1]})." + ) + embeddings, (patch_height, patch_width) = self.patch_embeddings( pixel_values, self.position_embeddings[:, 1:, :] if self.position_embeddings is not None else None ) @@ -158,7 +212,10 @@ class BeitEmbeddings(nn.Module): cls_tokens = self.cls_token.expand(batch_size, -1, -1) if self.position_embeddings is not None: - cls_tokens = cls_tokens + self.position_embeddings[:, :1, :] + if interpolate_pos_encoding: + cls_tokens = cls_tokens + self.interpolate_pos_encoding(embeddings, height, width) + else: + cls_tokens = cls_tokens + self.position_embeddings[:, :1, :] embeddings = torch.cat((cls_tokens, embeddings), dim=1) @@ -191,7 +248,11 @@ class BeitPatchEmbeddings(nn.Module): self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) - def forward(self, pixel_values: torch.Tensor, position_embedding: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward( + self, + pixel_values: torch.Tensor, + position_embedding: Optional[torch.Tensor] = None, + ) -> torch.Tensor: batch_size, num_channels, height, width = pixel_values.shape if num_channels != self.num_channels: raise ValueError( @@ -251,6 +312,7 @@ class BeitSelfAttention(nn.Module): head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, relative_position_bias: Optional["BeitRelativePositionBias"] = None, + interpolate_pos_encoding: bool = False, ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: mixed_query_layer = self.query(hidden_states) @@ -265,7 +327,9 @@ class BeitSelfAttention(nn.Module): # Add relative position bias if present. if self.relative_position_bias is not None: - attention_scores = attention_scores + self.relative_position_bias().unsqueeze(0) + attention_scores = attention_scores + self.relative_position_bias( + interpolate_pos_encoding, attention_scores.shape[2] + ).unsqueeze(0) # Add shared relative position bias if provided. if relative_position_bias is not None: @@ -342,8 +406,11 @@ class BeitAttention(nn.Module): head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, relative_position_bias: Optional["BeitRelativePositionBias"] = None, + interpolate_pos_encoding: bool = False, ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: - self_outputs = self.attention(hidden_states, head_mask, output_attentions, relative_position_bias) + self_outputs = self.attention( + hidden_states, head_mask, output_attentions, relative_position_bias, interpolate_pos_encoding + ) attention_output = self.output(self_outputs[0], hidden_states) @@ -407,12 +474,14 @@ class BeitLayer(nn.Module): head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, relative_position_bias: Optional["BeitRelativePositionBias"] = None, + interpolate_pos_encoding: bool = False, ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: self_attention_outputs = self.attention( self.layernorm_before(hidden_states), # in BEiT, layernorm is applied before self-attention head_mask, output_attentions=output_attentions, relative_position_bias=relative_position_bias, + interpolate_pos_encoding=interpolate_pos_encoding, ) attention_output = self_attention_outputs[0] outputs = self_attention_outputs[1:] # add self attentions if we output attention weights @@ -471,12 +540,21 @@ class BeitRelativePositionBias(nn.Module): self.register_buffer("relative_position_index", relative_position_index, persistent=False) - def forward(self) -> torch.Tensor: + def forward(self, interpolate_pos_encoding: bool = False, dim_size: Optional[int] = None) -> torch.Tensor: relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1 ) # Wh*Ww,Wh*Ww,nH - return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + if interpolate_pos_encoding: + relative_position_bias = nn.functional.interpolate( + relative_position_bias.unsqueeze(1), + size=(dim_size, dim_size), + mode="bilinear", + align_corners=False, + ).squeeze(1) + + return relative_position_bias class BeitEncoder(nn.Module): @@ -508,6 +586,7 @@ class BeitEncoder(nn.Module): head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, output_hidden_states: bool = False, + interpolate_pos_encoding: bool = False, return_dict: bool = True, ) -> Union[tuple, BaseModelOutput]: all_hidden_states = () if output_hidden_states else None @@ -528,9 +607,13 @@ class BeitEncoder(nn.Module): ) else: relative_position_bias = ( - self.relative_position_bias() if self.relative_position_bias is not None else None + self.relative_position_bias(interpolate_pos_encoding, hidden_states.shape[1]) + if self.relative_position_bias is not None + else None + ) + layer_outputs = layer_module( + hidden_states, layer_head_mask, output_attentions, relative_position_bias, interpolate_pos_encoding ) - layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions, relative_position_bias) hidden_states = layer_outputs[0] @@ -607,6 +690,8 @@ BEIT_INPUTS_DOCSTRING = r""" output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. + interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): + Whether to interpolate the pre-trained position encodings. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ @@ -658,6 +743,7 @@ class BeitModel(BeitPreTrainedModel): head_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> Union[tuple, BeitModelOutputWithPooling]: r""" @@ -680,7 +766,9 @@ class BeitModel(BeitPreTrainedModel): # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - embedding_output, (patch_height, patch_width) = self.embeddings(pixel_values, bool_masked_pos) + embedding_output, (patch_height, patch_width) = self.embeddings( + pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding + ) encoder_outputs = self.encoder( embedding_output, @@ -688,6 +776,7 @@ class BeitModel(BeitPreTrainedModel): output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + interpolate_pos_encoding=interpolate_pos_encoding, ) sequence_output = encoder_outputs[0] sequence_output = self.layernorm(sequence_output) @@ -755,6 +844,7 @@ class BeitForMaskedImageModeling(BeitPreTrainedModel): labels: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> Union[tuple, MaskedLMOutput]: r""" @@ -800,6 +890,7 @@ class BeitForMaskedImageModeling(BeitPreTrainedModel): head_mask=head_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) @@ -858,6 +949,7 @@ class BeitForImageClassification(BeitPreTrainedModel): labels: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> Union[tuple, ImageClassifierOutput]: r""" @@ -872,6 +964,7 @@ class BeitForImageClassification(BeitPreTrainedModel): head_mask=head_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) @@ -1215,6 +1308,7 @@ class BeitForSemanticSegmentation(BeitPreTrainedModel): labels: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> Union[tuple, SemanticSegmenterOutput]: r""" @@ -1255,6 +1349,7 @@ class BeitForSemanticSegmentation(BeitPreTrainedModel): head_mask=head_mask, output_attentions=output_attentions, output_hidden_states=True, # we need the intermediate hidden states + interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) diff --git a/src/transformers/models/data2vec/modeling_data2vec_vision.py b/src/transformers/models/data2vec/modeling_data2vec_vision.py index df793e67d1..a79810d0c5 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_vision.py +++ b/src/transformers/models/data2vec/modeling_data2vec_vision.py @@ -136,6 +136,12 @@ class Data2VecVisionEmbeddings(nn.Module): else: self.mask_token = None self.patch_embeddings = Data2VecVisionPatchEmbeddings(config) + self.patch_size = config.patch_size + self.image_size = ( + config.image_size + if isinstance(config.image_size, collections.abc.Iterable) + else (config.image_size, config.image_size) + ) num_patches = self.patch_embeddings.num_patches if config.use_absolute_position_embeddings: self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size)) @@ -143,7 +149,55 @@ class Data2VecVisionEmbeddings(nn.Module): self.position_embeddings = None self.dropout = nn.Dropout(config.hidden_dropout_prob) - def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.BoolTensor] = None) -> torch.Tensor: + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows the model to interpolate the pre-trained position encodings so that it can be used on + higher resolution images. + + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + if num_patches == num_positions and height == width: + return self.position_embeddings + + class_pos_embed = self.position_embeddings[:, 0] + patch_pos_embed = self.position_embeddings[:, 1:] + dim = embeddings.shape[-1] + h = height // self.patch_size + w = width // self.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + h, w = h + 0.1, w + 0.1 + + patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + scale_factor=(h / math.sqrt(num_positions), w / math.sqrt(num_positions)), + mode="bicubic", + align_corners=False, + ) + if int(h) != patch_pos_embed.shape[-2] or int(w) != patch_pos_embed.shape[-1]: + raise ValueError("Width or height does not match with the interpolated position embeddings") + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + def forward( + self, + pixel_values: torch.Tensor, + bool_masked_pos: Optional[torch.BoolTensor] = None, + interpolate_pos_encoding: bool = False, + ) -> torch.Tensor: + _, _, height, width = pixel_values.shape + if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]): + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model" + f" ({self.image_size[0]}*{self.image_size[1]})." + ) + embeddings, (patch_height, patch_width) = self.patch_embeddings( pixel_values, self.position_embeddings[:, 1:, :] if self.position_embeddings is not None else None ) @@ -157,7 +211,10 @@ class Data2VecVisionEmbeddings(nn.Module): cls_tokens = self.cls_token.expand(batch_size, -1, -1) if self.position_embeddings is not None: - cls_tokens = cls_tokens + self.position_embeddings[:, :1, :] + if interpolate_pos_encoding: + cls_tokens = cls_tokens + self.interpolate_pos_encoding(embeddings, height, width) + else: + cls_tokens = cls_tokens + self.position_embeddings[:, :1, :] embeddings = torch.cat((cls_tokens, embeddings), dim=1) @@ -191,7 +248,11 @@ class Data2VecVisionPatchEmbeddings(nn.Module): self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) - def forward(self, pixel_values: torch.Tensor, position_embedding: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward( + self, + pixel_values: torch.Tensor, + position_embedding: Optional[torch.Tensor] = None, + ) -> torch.Tensor: batch_size, num_channels, height, width = pixel_values.shape if num_channels != self.num_channels: raise ValueError( @@ -252,6 +313,7 @@ class Data2VecVisionSelfAttention(nn.Module): head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, relative_position_bias: Optional["Data2VecVisionRelativePositionBias"] = None, + interpolate_pos_encoding: bool = False, ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: mixed_query_layer = self.query(hidden_states) @@ -266,7 +328,9 @@ class Data2VecVisionSelfAttention(nn.Module): # Add relative position bias if present. if self.relative_position_bias is not None: - attention_scores = attention_scores + self.relative_position_bias().unsqueeze(0) + attention_scores = attention_scores + self.relative_position_bias( + interpolate_pos_encoding, attention_scores.shape[2] + ).unsqueeze(0) # Add shared relative position bias if provided. if relative_position_bias is not None: @@ -345,8 +409,11 @@ class Data2VecVisionAttention(nn.Module): head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, relative_position_bias: Optional["Data2VecVisionRelativePositionBias"] = None, + interpolate_pos_encoding: bool = False, ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: - self_outputs = self.attention(hidden_states, head_mask, output_attentions, relative_position_bias) + self_outputs = self.attention( + hidden_states, head_mask, output_attentions, relative_position_bias, interpolate_pos_encoding + ) attention_output = self.output(self_outputs[0], hidden_states) @@ -415,12 +482,14 @@ class Data2VecVisionLayer(nn.Module): head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, relative_position_bias: Optional["Data2VecVisionRelativePositionBias"] = None, + interpolate_pos_encoding: bool = False, ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: self_attention_outputs = self.attention( self.layernorm_before(hidden_states), # in Data2VecVision, layernorm is applied before self-attention head_mask, output_attentions=output_attentions, relative_position_bias=relative_position_bias, + interpolate_pos_encoding=interpolate_pos_encoding, ) attention_output = self_attention_outputs[0] outputs = self_attention_outputs[1:] # add self attentions if we output attention weights @@ -480,12 +549,21 @@ class Data2VecVisionRelativePositionBias(nn.Module): self.register_buffer("relative_position_index", relative_position_index, persistent=False) - def forward(self) -> torch.Tensor: + def forward(self, interpolate_pos_encoding: bool = False, dim_size: Optional[int] = None) -> torch.Tensor: relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1 ) # Wh*Ww,Wh*Ww,nH - return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + if interpolate_pos_encoding: + relative_position_bias = nn.functional.interpolate( + relative_position_bias.unsqueeze(1), + size=(dim_size, dim_size), + mode="bilinear", + align_corners=False, + ).squeeze(1) + + return relative_position_bias # Copied from transformers.models.beit.modeling_beit.BeitEncoder with Beit->Data2VecVision @@ -518,6 +596,7 @@ class Data2VecVisionEncoder(nn.Module): head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, output_hidden_states: bool = False, + interpolate_pos_encoding: bool = False, return_dict: bool = True, ) -> Union[tuple, BaseModelOutput]: all_hidden_states = () if output_hidden_states else None @@ -538,9 +617,13 @@ class Data2VecVisionEncoder(nn.Module): ) else: relative_position_bias = ( - self.relative_position_bias() if self.relative_position_bias is not None else None + self.relative_position_bias(interpolate_pos_encoding, hidden_states.shape[1]) + if self.relative_position_bias is not None + else None + ) + layer_outputs = layer_module( + hidden_states, layer_head_mask, output_attentions, relative_position_bias, interpolate_pos_encoding ) - layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions, relative_position_bias) hidden_states = layer_outputs[0] @@ -618,6 +701,8 @@ DATA2VEC_VISION_INPUTS_DOCSTRING = r""" output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. + interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): + Whether to interpolate the pre-trained position encodings. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ @@ -670,6 +755,7 @@ class Data2VecVisionModel(Data2VecVisionPreTrainedModel): head_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> Union[tuple, Data2VecVisionModelOutputWithPooling]: r""" @@ -692,7 +778,9 @@ class Data2VecVisionModel(Data2VecVisionPreTrainedModel): # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - embedding_output, (patch_height, patch_width) = self.embeddings(pixel_values, bool_masked_pos) + embedding_output, (patch_height, patch_width) = self.embeddings( + pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding + ) encoder_outputs = self.encoder( embedding_output, @@ -700,6 +788,7 @@ class Data2VecVisionModel(Data2VecVisionPreTrainedModel): output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + interpolate_pos_encoding=interpolate_pos_encoding, ) sequence_output = encoder_outputs[0] sequence_output = self.layernorm(sequence_output) @@ -772,6 +861,7 @@ class Data2VecVisionForImageClassification(Data2VecVisionPreTrainedModel): labels: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> Union[tuple, ImageClassifierOutput]: r""" @@ -786,6 +876,7 @@ class Data2VecVisionForImageClassification(Data2VecVisionPreTrainedModel): head_mask=head_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) @@ -1141,6 +1232,7 @@ class Data2VecVisionForSemanticSegmentation(Data2VecVisionPreTrainedModel): labels: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> Union[tuple, SemanticSegmenterOutput]: r""" @@ -1181,6 +1273,7 @@ class Data2VecVisionForSemanticSegmentation(Data2VecVisionPreTrainedModel): head_mask=head_mask, output_attentions=output_attentions, output_hidden_states=True, # we need the intermediate hidden states + interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) diff --git a/tests/models/beit/test_modeling_beit.py b/tests/models/beit/test_modeling_beit.py index 1010c6007d..0fd17efaf6 100644 --- a/tests/models/beit/test_modeling_beit.py +++ b/tests/models/beit/test_modeling_beit.py @@ -545,6 +545,31 @@ class BeitModelIntegrationTest(unittest.TestCase): expected_shape = torch.Size((160, 160)) self.assertEqual(segmentation[0].shape, expected_shape) + @slow + def test_inference_interpolate_pos_encoding(self): + model_name = "microsoft/beit-base-patch16-224-pt22k" + model = BeitModel.from_pretrained(model_name, **{"use_absolute_position_embeddings": True}).to(torch_device) + + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + processor = BeitImageProcessor.from_pretrained(model_name) + inputs = processor(images=image, return_tensors="pt", size={"height": 480, "width": 480}) + pixel_values = inputs.pixel_values.to(torch_device) + + # with interpolate_pos_encoding being False an exception should be raised with higher resolution + # images than what the model supports. + self.assertFalse(processor.do_center_crop) + with torch.no_grad(): + with self.assertRaises(ValueError, msg="doesn't match model"): + model(pixel_values, interpolate_pos_encoding=False) + + # with interpolate_pos_encoding being True the model should process the higher resolution image + # successfully and produce the expected output. + with torch.no_grad(): + outputs = model(pixel_values, interpolate_pos_encoding=True) + + expected_shape = torch.Size((1, 1801, 768)) + self.assertEqual(outputs.last_hidden_state.shape, expected_shape) + @require_torch class BeitBackboneTest(unittest.TestCase, BackboneTesterMixin): diff --git a/tests/models/data2vec/test_modeling_data2vec_vision.py b/tests/models/data2vec/test_modeling_data2vec_vision.py index 99cbd66fbb..fabf543c02 100644 --- a/tests/models/data2vec/test_modeling_data2vec_vision.py +++ b/tests/models/data2vec/test_modeling_data2vec_vision.py @@ -341,3 +341,30 @@ class Data2VecVisionModelIntegrationTest(unittest.TestCase): expected_top2 = [model.config.label2id[i] for i in ["remote control, remote", "tabby, tabby cat"]] self.assertEqual(logits[0].topk(2).indices.cpu().tolist(), expected_top2) + + @slow + def test_inference_interpolate_pos_encoding(self): + model_name = "facebook/data2vec-vision-base-ft1k" + model = Data2VecVisionModel.from_pretrained(model_name, **{"use_absolute_position_embeddings": True}).to( + torch_device + ) + + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + processor = BeitImageProcessor.from_pretrained("facebook/data2vec-vision-base-ft1k") + inputs = processor(images=image, return_tensors="pt", size={"height": 480, "width": 480}) + pixel_values = inputs.pixel_values.to(torch_device) + + # with interpolate_pos_encoding being False an exception should be raised with higher resolution + # images than what the model supports. + self.assertFalse(processor.do_center_crop) + with torch.no_grad(): + with self.assertRaises(ValueError, msg="doesn't match model"): + model(pixel_values, interpolate_pos_encoding=False) + + # with interpolate_pos_encoding being True the model should process the higher resolution image + # successfully and produce the expected output. + with torch.no_grad(): + outputs = model(pixel_values, interpolate_pos_encoding=True) + + expected_shape = torch.Size((1, 1801, 768)) + self.assertEqual(outputs.last_hidden_state.shape, expected_shape)