From 8c5b3c19cf240ec6d4e195b830b0283a1ff32570 Mon Sep 17 00:00:00 2001 From: Jacky Lee <39754370+jla524@users.noreply.github.com> Date: Thu, 9 May 2024 03:23:39 -0700 Subject: [PATCH] Enable dynamic resolution for vivit (#30630) * feat: enable dynamic resolution for vivit * fix: formatting * remove: print statement for testing * Update src/transformers/models/vivit/modeling_vivit.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/vivit/modeling_vivit.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/vivit/modeling_vivit.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update tests/models/vivit/test_modeling_vivit.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update tests/models/vivit/test_modeling_vivit.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/vivit/modeling_vivit.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update tests/models/vivit/test_modeling_vivit.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/vivit/modeling_vivit.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/vivit/modeling_vivit.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/vivit/modeling_vivit.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/vivit/modeling_vivit.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * fix: style check --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- .../models/vivit/modeling_vivit.py | 72 ++++++++++++++++--- tests/models/vivit/test_modeling_vivit.py | 23 ++++++ 2 files changed, 84 insertions(+), 11 deletions(-) diff --git a/src/transformers/models/vivit/modeling_vivit.py b/src/transformers/models/vivit/modeling_vivit.py index aa96237356..ef94b836a4 100755 --- a/src/transformers/models/vivit/modeling_vivit.py +++ b/src/transformers/models/vivit/modeling_vivit.py @@ -67,11 +67,12 @@ class VivitTubeletEmbeddings(nn.Module): config.num_channels, config.hidden_size, kernel_size=config.tubelet_size, stride=config.tubelet_size ) - def forward(self, pixel_values): + def forward(self, pixel_values, interpolate_pos_encoding: bool = False): batch_size, num_frames, num_channels, height, width = pixel_values.shape - if height != self.image_size or width != self.image_size: + if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size): raise ValueError( - f"Input image size ({height}*{width}) doesn't match model ({self.image_size}*{self.image_size})." + f"Image image size ({height}*{width}) doesn't match model" + f" ({self.image_size[0]}*{self.image_size[1]})." ) # permute to (batch_size, num_channels, num_frames, height, width) @@ -102,16 +103,50 @@ class VivitEmbeddings(nn.Module): self.dropout = nn.Dropout(config.hidden_dropout_prob) self.config = config - def forward(self, pixel_values): - batch_size = pixel_values.shape[0] - embeddings = self.patch_embeddings(pixel_values) + def interpolate_pos_encoding(self, embeddings, height, width): + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. + + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + if num_patches == num_positions and height == width: + return self.position_embeddings + class_pos_embed = self.position_embeddings[:, 0] + patch_pos_embed = self.position_embeddings[:, 1:] + dim = embeddings.shape[-1] + h0 = height // self.config.patch_size + w0 = width // self.config.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + h0, w0 = h0 + 0.1, w0 + 0.1 + patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)), + mode="bicubic", + align_corners=False, + ) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + def forward(self, pixel_values, interpolate_pos_encoding: bool = False): + batch_size, num_frames, num_channels, height, width = pixel_values.shape + embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) cls_tokens = self.cls_token.tile([batch_size, 1, 1]) - embeddings = torch.cat((cls_tokens, embeddings), dim=1) # add positional encoding to each token - embeddings = embeddings + self.position_embeddings + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embeddings embeddings = self.dropout(embeddings) @@ -437,6 +472,8 @@ VIVIT_INPUTS_DOCSTRING = r""" output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. + interpolate_pos_encoding (`bool`, *optional*, `False`): + Whether to interpolate the pre-trained position encodings. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ @@ -482,6 +519,7 @@ class VivitModel(VivitPreTrainedModel): head_mask: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> Union[Tuple[torch.FloatTensor], BaseModelOutputWithPooling]: r""" @@ -571,7 +609,7 @@ class VivitModel(VivitPreTrainedModel): head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - embedding_output = self.embeddings(pixel_values) + embedding_output = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) encoder_outputs = self.encoder( embedding_output, @@ -596,8 +634,18 @@ class VivitModel(VivitPreTrainedModel): @add_start_docstrings( - """ViViT Transformer model with a video classification head on top (a linear layer on top of the final hidden state of the -[CLS] token) e.g. for Kinetics-400.""", + """ + ViViT Transformer model with a video classification head on top (a linear layer on top of the final hidden state of the +[CLS] token) e.g. for Kinetics-400. + + + + Note that it's possible to fine-tune ViT on higher resolution images than the ones it has been trained on, by + setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained + position embeddings to the higher resolution. + + + """, VIVIT_START_DOCSTRING, ) class VivitForVideoClassification(VivitPreTrainedModel): @@ -622,6 +670,7 @@ class VivitForVideoClassification(VivitPreTrainedModel): labels: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> Union[Tuple[torch.FloatTensor], ImageClassifierOutput]: r""" @@ -715,6 +764,7 @@ class VivitForVideoClassification(VivitPreTrainedModel): head_mask=head_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) diff --git a/tests/models/vivit/test_modeling_vivit.py b/tests/models/vivit/test_modeling_vivit.py index 9b299c9afa..cbb45731ae 100644 --- a/tests/models/vivit/test_modeling_vivit.py +++ b/tests/models/vivit/test_modeling_vivit.py @@ -353,3 +353,26 @@ class VivitModelIntegrationTest(unittest.TestCase): expected_slice = torch.tensor([-0.9498, 2.7971, -1.4049, 0.1024, -1.8353]).to(torch_device) self.assertTrue(torch.allclose(outputs.logits[0, :5], expected_slice, atol=1e-4)) + + @slow + def test_inference_interpolate_pos_encoding(self): + # Vivit models have an `interpolate_pos_encoding` argument in their forward method, + # allowing to interpolate the pre-trained position embeddings in order to use + # the model on higher resolutions. The DINO model by Facebook AI leverages this + # to visualize self-attention on higher resolution images. + model = VivitModel.from_pretrained("google/vivit-b-16x2").to(torch_device) + + image_processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2") + video = prepare_video() + inputs = image_processor( + video, size={"shortest_edge": 480}, crop_size={"height": 480, "width": 480}, return_tensors="pt" + ) + pixel_values = inputs.pixel_values.to(torch_device) + + # forward pass + with torch.no_grad(): + outputs = model(pixel_values, interpolate_pos_encoding=True) + + # verify the logits shape + expected_shape = torch.Size((1, 3137, 768)) + self.assertEqual(outputs.last_hidden_state.shape, expected_shape)