Add dynamic resolution input/interpolate position embedding to deit (#31131)

* Added interpolate pos encoding feature and test to deit

* Added interpolate pos encoding feature and test for deit TF model

* re-added accidentally deleted test for multi_gpu

* storing only patch_size instead of entire config and removed commented code

* Update modeling_tf_deit.py to remove extra line

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Kristen Pereira
2024-06-04 05:29:01 -04:00
committed by GitHub
parent d64e4da713
commit de460e28e1
4 changed files with 161 additions and 14 deletions

View File

@@ -73,9 +73,53 @@ class DeiTEmbeddings(nn.Module):
num_patches = self.patch_embeddings.num_patches num_patches = self.patch_embeddings.num_patches
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 2, config.hidden_size)) self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 2, config.hidden_size))
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.patch_size = config.patch_size
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
    """
    Interpolate the pre-trained position encodings so the model can be used on images whose resolution differs
    from the one the position table was built for.

    Source:
    https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174

    Args:
        embeddings (`torch.Tensor`):
            Token embeddings. Dimension 1 holds two leading special tokens (class, distillation) followed by the
            patch tokens; the last dimension is the hidden size.
        height (`int`):
            Height of the input image in pixels.
        width (`int`):
            Width of the input image in pixels.

    Returns:
        `torch.Tensor`: `self.position_embeddings` itself when the number of patches matches the table and the
        image is square; otherwise a new tensor made of the two untouched special-token encodings followed by
        `(height // patch_size) * (width // patch_size)` bicubically interpolated patch encodings.

    Raises:
        ValueError: If the interpolated grid does not have the requested `(height // patch_size,
        width // patch_size)` shape.
    """
    num_patches = embeddings.shape[1] - 2
    num_positions = self.position_embeddings.shape[1] - 2

    # Fast path: nothing to interpolate.
    if num_patches == num_positions and height == width:
        return self.position_embeddings

    class_pos_embed = self.position_embeddings[:, 0, :]
    dist_pos_embed = self.position_embeddings[:, 1, :]
    patch_pos_embed = self.position_embeddings[:, 2:, :]
    dim = embeddings.shape[-1]

    # Target grid, in patches.
    grid_h = height // self.patch_size
    grid_w = width // self.patch_size
    # Side of the pre-trained patch grid; the table is treated as a square grid.
    source_side = math.sqrt(num_positions)

    # (1, N, dim) -> (1, side, side, dim) -> channels-first (1, dim, side, side) for `interpolate`.
    patch_pos_embed = patch_pos_embed.reshape(1, int(source_side), int(source_side), dim)
    patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
    patch_pos_embed = nn.functional.interpolate(
        patch_pos_embed,
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        scale_factor=((grid_h + 0.1) / source_side, (grid_w + 0.1) / source_side),
        mode="bicubic",
        align_corners=False,
    )

    # The scale-factor workaround above is only valid if it really produced the requested grid; fail loudly
    # here instead of letting a shape mismatch surface later when the encodings are added to the embeddings.
    if tuple(patch_pos_embed.shape[-2:]) != (grid_h, grid_w):
        raise ValueError(
            f"Interpolated position grid {tuple(patch_pos_embed.shape[-2:])} does not match the expected "
            f"grid ({grid_h}, {grid_w})."
        )

    # Back to channels-last and flatten the grid into a token sequence.
    patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
    return torch.cat((class_pos_embed.unsqueeze(0), dist_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
def forward(
self,
pixel_values: torch.Tensor,
bool_masked_pos: Optional[torch.BoolTensor] = None,
interpolate_pos_encoding: bool = False,
) -> torch.Tensor:
_, _, height, width = pixel_values.shape
embeddings = self.patch_embeddings(pixel_values) embeddings = self.patch_embeddings(pixel_values)
batch_size, seq_length, _ = embeddings.size() batch_size, seq_length, _ = embeddings.size()
if bool_masked_pos is not None: if bool_masked_pos is not None:
@@ -85,9 +129,16 @@ class DeiTEmbeddings(nn.Module):
embeddings = embeddings * (1.0 - mask) + mask_tokens * mask embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
cls_tokens = self.cls_token.expand(batch_size, -1, -1) cls_tokens = self.cls_token.expand(batch_size, -1, -1)
distillation_tokens = self.distillation_token.expand(batch_size, -1, -1) distillation_tokens = self.distillation_token.expand(batch_size, -1, -1)
embeddings = torch.cat((cls_tokens, distillation_tokens, embeddings), dim=1) embeddings = torch.cat((cls_tokens, distillation_tokens, embeddings), dim=1)
embeddings = embeddings + self.position_embeddings position_embedding = self.position_embeddings
if interpolate_pos_encoding:
position_embedding = self.interpolate_pos_encoding(embeddings, height, width)
embeddings = embeddings + position_embedding
embeddings = self.dropout(embeddings) embeddings = self.dropout(embeddings)
return embeddings return embeddings
@@ -120,10 +171,6 @@ class DeiTPatchEmbeddings(nn.Module):
raise ValueError( raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration." "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
) )
if height != self.image_size[0] or width != self.image_size[1]:
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
)
x = self.projection(pixel_values).flatten(2).transpose(1, 2) x = self.projection(pixel_values).flatten(2).transpose(1, 2)
return x return x
@@ -480,6 +527,8 @@ DEIT_INPUTS_DOCSTRING = r"""
more detail. more detail.
return_dict (`bool`, *optional*): return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the pre-trained position encodings.
""" """
@@ -528,6 +577,7 @@ class DeiTModel(DeiTPreTrainedModel):
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
) -> Union[Tuple, BaseModelOutputWithPooling]: ) -> Union[Tuple, BaseModelOutputWithPooling]:
r""" r"""
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*): bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
@@ -554,7 +604,9 @@ class DeiTModel(DeiTPreTrainedModel):
if pixel_values.dtype != expected_dtype: if pixel_values.dtype != expected_dtype:
pixel_values = pixel_values.to(expected_dtype) pixel_values = pixel_values.to(expected_dtype)
embedding_output = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos) embedding_output = self.embeddings(
pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
)
encoder_outputs = self.encoder( encoder_outputs = self.encoder(
embedding_output, embedding_output,
@@ -635,6 +687,7 @@ class DeiTForMaskedImageModeling(DeiTPreTrainedModel):
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
) -> Union[tuple, MaskedImageModelingOutput]: ) -> Union[tuple, MaskedImageModelingOutput]:
r""" r"""
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`): bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
@@ -674,6 +727,7 @@ class DeiTForMaskedImageModeling(DeiTPreTrainedModel):
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
) )
sequence_output = outputs[0] sequence_output = outputs[0]
@@ -742,6 +796,7 @@ class DeiTForImageClassification(DeiTPreTrainedModel):
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
) -> Union[tuple, ImageClassifierOutput]: ) -> Union[tuple, ImageClassifierOutput]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -784,6 +839,7 @@ class DeiTForImageClassification(DeiTPreTrainedModel):
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
) )
sequence_output = outputs[0] sequence_output = outputs[0]
@@ -901,6 +957,7 @@ class DeiTForImageClassificationWithTeacher(DeiTPreTrainedModel):
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
) -> Union[tuple, DeiTForImageClassificationWithTeacherOutput]: ) -> Union[tuple, DeiTForImageClassificationWithTeacherOutput]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -910,6 +967,7 @@ class DeiTForImageClassificationWithTeacher(DeiTPreTrainedModel):
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
) )
sequence_output = outputs[0] sequence_output = outputs[0]

View File

@@ -146,9 +146,42 @@ class TFDeiTEmbeddings(keras.layers.Layer):
with tf.name_scope(self.dropout.name): with tf.name_scope(self.dropout.name):
self.dropout.build(None) self.dropout.build(None)
def interpolate_pos_encoding(self, embeddings: tf.Tensor, height: int, width: int) -> tf.Tensor:
    """
    Interpolate the pre-trained position encodings so the model can be used on images whose resolution differs
    from the one the position table was built for.

    Args:
        embeddings (`tf.Tensor`):
            Token embeddings. Dimension 1 holds two leading special tokens (class, distillation) followed by the
            patch tokens; the last dimension is the hidden size.
        height (`int`):
            Height of the input image in pixels.
        width (`int`):
            Width of the input image in pixels.

    Returns:
        `tf.Tensor`: `self.position_embeddings` itself when the number of patches matches the table and the
        image is square; otherwise the two untouched special-token encodings followed by
        `(height // patch_size) * (width // patch_size)` bicubically resized patch encodings.
    """
    num_patches = embeddings.shape[1] - 2
    num_positions = self.position_embeddings.shape[1] - 2

    # Fast path: nothing to interpolate.
    if num_patches == num_positions and height == width:
        return self.position_embeddings

    class_pos_embed = self.position_embeddings[:, 0, :]
    dist_pos_embed = self.position_embeddings[:, 1, :]
    patch_pos_embed = self.position_embeddings[:, 2:, :]
    dim = embeddings.shape[-1]

    # Target grid, in patches. `tf.image.resize` takes an explicit integer size, so no floating-point
    # scale-factor workaround is needed here.
    grid_h = height // self.config.patch_size
    grid_w = width // self.config.patch_size
    # Side of the pre-trained patch grid; the table is treated as a square grid.
    source_side = int(math.sqrt(num_positions))

    # (1, N, dim) -> (1, side, side, dim): already the channels-last layout `tf.image.resize` expects.
    patch_pos_embed = tf.reshape(patch_pos_embed, (1, source_side, source_side, dim))
    patch_pos_embed = tf.image.resize(patch_pos_embed, size=(grid_h, grid_w), method="bicubic")
    # The resized tensor is (1, grid_h, grid_w, dim); flatten the grid directly. Transposing first (as the
    # channels-first PyTorch code path does) would move `dim` off the last axis and scramble the encodings.
    patch_pos_embed = tf.reshape(patch_pos_embed, (1, -1, dim))

    return tf.concat(
        [tf.expand_dims(class_pos_embed, axis=0), tf.expand_dims(dist_pos_embed, axis=0), patch_pos_embed], axis=1
    )
def call( def call(
self, pixel_values: tf.Tensor, bool_masked_pos: tf.Tensor | None = None, training: bool = False self,
pixel_values: tf.Tensor,
bool_masked_pos: tf.Tensor | None = None,
training: bool = False,
interpolate_pos_encoding: bool = False,
) -> tf.Tensor: ) -> tf.Tensor:
_, height, width, _ = pixel_values.shape
embeddings = self.patch_embeddings(pixel_values) embeddings = self.patch_embeddings(pixel_values)
batch_size, seq_length, _ = shape_list(embeddings) batch_size, seq_length, _ = shape_list(embeddings)
@@ -162,7 +195,11 @@ class TFDeiTEmbeddings(keras.layers.Layer):
cls_tokens = tf.repeat(self.cls_token, repeats=batch_size, axis=0) cls_tokens = tf.repeat(self.cls_token, repeats=batch_size, axis=0)
distillation_tokens = tf.repeat(self.distillation_token, repeats=batch_size, axis=0) distillation_tokens = tf.repeat(self.distillation_token, repeats=batch_size, axis=0)
embeddings = tf.concat((cls_tokens, distillation_tokens, embeddings), axis=1) embeddings = tf.concat((cls_tokens, distillation_tokens, embeddings), axis=1)
embeddings = embeddings + self.position_embeddings position_embedding = self.position_embeddings
if interpolate_pos_encoding:
position_embedding = self.interpolate_pos_encoding(embeddings, height, width)
embeddings = embeddings + position_embedding
embeddings = self.dropout(embeddings, training=training) embeddings = self.dropout(embeddings, training=training)
return embeddings return embeddings
@@ -197,10 +234,7 @@ class TFDeiTPatchEmbeddings(keras.layers.Layer):
raise ValueError( raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration." "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
) )
if tf.executing_eagerly() and (height != self.image_size[0] or width != self.image_size[1]):
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
)
x = self.projection(pixel_values) x = self.projection(pixel_values)
batch_size, height, width, num_channels = shape_list(x) batch_size, height, width, num_channels = shape_list(x)
x = tf.reshape(x, (batch_size, height * width, num_channels)) x = tf.reshape(x, (batch_size, height * width, num_channels))
@@ -599,6 +633,7 @@ class TFDeiTMainLayer(keras.layers.Layer):
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
training: bool = False, training: bool = False,
) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor, ...]]: ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor, ...]]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -621,7 +656,12 @@ class TFDeiTMainLayer(keras.layers.Layer):
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask) head_mask = self.get_head_mask(head_mask)
embedding_output = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos, training=training) embedding_output = self.embeddings(
pixel_values,
bool_masked_pos=bool_masked_pos,
training=training,
interpolate_pos_encoding=interpolate_pos_encoding,
)
encoder_outputs = self.encoder( encoder_outputs = self.encoder(
embedding_output, embedding_output,
@@ -705,6 +745,8 @@ DEIT_INPUTS_DOCSTRING = r"""
output_hidden_states (`bool`, *optional*): output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail. more detail.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the pre-trained position encodings.
return_dict (`bool`, *optional*): return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
""" """
@@ -741,6 +783,7 @@ class TFDeiTModel(TFDeiTPreTrainedModel):
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
training: bool = False, training: bool = False,
) -> Union[Tuple, TFBaseModelOutputWithPooling]: ) -> Union[Tuple, TFBaseModelOutputWithPooling]:
outputs = self.deit( outputs = self.deit(
@@ -750,6 +793,7 @@ class TFDeiTModel(TFDeiTPreTrainedModel):
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
training=training, training=training,
) )
return outputs return outputs
@@ -869,6 +913,7 @@ class TFDeiTForMaskedImageModeling(TFDeiTPreTrainedModel):
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
training: bool = False, training: bool = False,
) -> Union[tuple, TFMaskedImageModelingOutput]: ) -> Union[tuple, TFMaskedImageModelingOutput]:
r""" r"""
@@ -909,6 +954,7 @@ class TFDeiTForMaskedImageModeling(TFDeiTPreTrainedModel):
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
training=training, training=training,
) )
@@ -1003,6 +1049,7 @@ class TFDeiTForImageClassification(TFDeiTPreTrainedModel, TFSequenceClassificati
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
training: bool = False, training: bool = False,
) -> Union[tf.Tensor, TFImageClassifierOutput]: ) -> Union[tf.Tensor, TFImageClassifierOutput]:
r""" r"""
@@ -1046,6 +1093,7 @@ class TFDeiTForImageClassification(TFDeiTPreTrainedModel, TFSequenceClassificati
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
training=training, training=training,
) )
@@ -1126,6 +1174,7 @@ class TFDeiTForImageClassificationWithTeacher(TFDeiTPreTrainedModel):
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
training: bool = False, training: bool = False,
) -> Union[tuple, TFDeiTForImageClassificationWithTeacherOutput]: ) -> Union[tuple, TFDeiTForImageClassificationWithTeacherOutput]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -1136,6 +1185,7 @@ class TFDeiTForImageClassificationWithTeacher(TFDeiTPreTrainedModel):
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
training=training, training=training,
) )

View File

@@ -423,6 +423,28 @@ class DeiTModelIntegrationTest(unittest.TestCase):
self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4)) self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
@slow
def test_inference_interpolate_pos_encoding(self):
    """Run the distilled DeiT classifier on a non-224x224 input with position-encoding interpolation enabled."""
    checkpoint = "facebook/deit-base-distilled-patch16-224"
    model = DeiTForImageClassificationWithTeacher.from_pretrained(checkpoint).to(torch_device)

    # Ask the processor for a {"height": 480, "width": 640} image.
    processor = self.default_image_processor
    processor.size = {"height": 480, "width": 640}

    # Center cropping is disabled so the image is not cropped back to 224x224.
    encoded = processor(images=prepare_img(), return_tensors="pt", do_center_crop=False).to(torch_device)

    # Forward pass without gradient tracking.
    with torch.no_grad():
        outputs = model(**encoded, interpolate_pos_encoding=True)

    # Only the shape of the logits is verified.
    self.assertEqual(outputs.logits.shape, torch.Size((1, 1000)))
@slow @slow
@require_accelerate @require_accelerate
@require_torch_accelerator @require_torch_accelerator

View File

@@ -293,3 +293,20 @@ class DeiTModelIntegrationTest(unittest.TestCase):
expected_slice = tf.constant([-1.0266, 0.1912, -1.2861]) expected_slice = tf.constant([-1.0266, 0.1912, -1.2861])
self.assertTrue(np.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4)) self.assertTrue(np.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
@slow
def test_inference_interpolate_pos_encoding(self):
    """Run the distilled TF DeiT classifier on a non-224x224 input with position-encoding interpolation enabled."""
    checkpoint = "facebook/deit-base-distilled-patch16-224"
    model = TFDeiTForImageClassificationWithTeacher.from_pretrained(checkpoint)

    # Ask the processor for a {"height": 480, "width": 640} image.
    processor = self.default_image_processor
    processor.size = {"height": 480, "width": 640}

    # Center cropping is disabled so the image is not cropped back to 224x224.
    encoded = processor(images=prepare_img(), return_tensors="tf", do_center_crop=False)

    # Forward pass.
    outputs = model(**encoded, interpolate_pos_encoding=True)

    # Only the shape of the logits is verified.
    self.assertEqual(outputs.logits.shape, tf.TensorShape((1, 1000)))