From 42d8dd8716d45291733191464e249a83b93e4d86 Mon Sep 17 00:00:00 2001 From: Yixiang Gao Date: Fri, 24 May 2024 05:13:58 -0500 Subject: [PATCH] Perceiver interpolate position embedding (#30979) * add test that currently fails * test passed * all perceiver passed * fixup, style, quality, repo-consistency, all passed * Apply suggestions from code review: default to False + compute sqrt once only Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * fix a minor bracket * replace dim with self._num_channels * add arguments to the rest preprocessors --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- .../models/perceiver/modeling_perceiver.py | 83 ++++++++++++++++--- .../perceiver/test_modeling_perceiver.py | 20 +++++ 2 files changed, 93 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/perceiver/modeling_perceiver.py b/src/transformers/models/perceiver/modeling_perceiver.py index 4921e292d6..4f9f0c05af 100755 --- a/src/transformers/models/perceiver/modeling_perceiver.py +++ b/src/transformers/models/perceiver/modeling_perceiver.py @@ -699,13 +699,24 @@ PERCEIVER_INPUTS_DOCSTRING = r""" output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. + interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): + Whether to interpolate the pre-trained position encodings. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ @add_start_docstrings( - """The Perceiver: a scalable, fully attentional architecture.""", + """The Perceiver: a scalable, fully attentional architecture. + + + + Note that it's possible to fine-tune Perceiver on higher resolution images than the ones it has been trained on, by + setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained + position embeddings to the higher resolution. + + + """, PERCEIVER_MODEL_START_DOCSTRING, ) class PerceiverModel(PerceiverPreTrainedModel): @@ -754,6 +765,7 @@ class PerceiverModel(PerceiverPreTrainedModel): head_mask: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> Union[Tuple, PerceiverModelOutput]: r""" @@ -857,7 +869,9 @@ class PerceiverModel(PerceiverPreTrainedModel): return_dict = return_dict if return_dict is not None else self.config.use_return_dict if self.input_preprocessor is not None: - inputs, modality_sizes, inputs_without_pos = self.input_preprocessor(inputs) + inputs, modality_sizes, inputs_without_pos = self.input_preprocessor( + inputs, interpolate_pos_encoding=interpolate_pos_encoding + ) else: modality_sizes = None inputs_without_pos = None @@ -1247,6 +1261,7 @@ class PerceiverForImageClassificationLearned(PerceiverPreTrainedModel): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, labels: Optional[torch.Tensor] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, pixel_values: Optional[torch.Tensor] = None, ) -> Union[Tuple, PerceiverClassifierOutput]: @@ -1295,6 +1310,7 @@ class PerceiverForImageClassificationLearned(PerceiverPreTrainedModel): head_mask=head_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) logits = outputs.logits if return_dict else outputs[0] @@ -2749,9 +2765,31 @@ class PerceiverTrainablePositionEncoding(PerceiverAbstractPositionEncoding): def output_size(self, *args, **kwargs) -> int: return self._num_channels - def forward(self, batch_size: int) -> torch.Tensor: + def interpolate_pos_encoding(self, position_embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + num_positions = position_embeddings.shape[0] + new_height = new_width = math.sqrt(num_positions) + position_embeddings = position_embeddings.reshape( + 1, int(new_height), int(new_width), self._num_channels + ).permute(0, 3, 1, 2) + position_embeddings = nn.functional.interpolate( + position_embeddings, + scale_factor=(height / new_height, width / new_width), + mode="bicubic", + align_corners=False, + ) + position_embeddings = position_embeddings.reshape(1, self._num_channels, -1).permute(0, 2, 1).squeeze(0) + return position_embeddings + + def forward( + self, batch_size: int, interpolate_pos_encoding: bool = False, input_size: torch.Size = None + ) -> torch.Tensor: position_embeddings = self.position_embeddings + if interpolate_pos_encoding: + height, width = input_size + height, width = height + 0.1, width + 0.1 + position_embeddings = self.interpolate_pos_encoding(position_embeddings, height, width) + if batch_size is not None: position_embeddings = position_embeddings.expand(batch_size, -1, -1) return position_embeddings @@ -2859,7 +2897,13 @@ class PerceiverTextPreprocessor(AbstractPreprocessor): def num_channels(self) -> int: return self.config.d_model - def forward(self, inputs: torch.LongTensor, pos: Optional[torch.Tensor] = None, network_input_is_1d: bool = True): + def forward( + self, + inputs: torch.LongTensor, + pos: Optional[torch.Tensor] = None, + network_input_is_1d: bool = True, + interpolate_pos_encoding: bool = False, + ): embeddings_without_pos = self.embeddings(inputs) seq_length = inputs.shape[1] @@ -3139,7 +3183,9 @@ class PerceiverImagePreprocessor(AbstractPreprocessor): return inp_dim + pos_dim - def _build_network_inputs(self, inputs: torch.Tensor, network_input_is_1d: bool = True): + def _build_network_inputs( + self, inputs: torch.Tensor, network_input_is_1d: bool = True, interpolate_pos_encoding: bool = False + ): """ Construct the final input, including position encoding. @@ -3147,6 +3193,7 @@ class PerceiverImagePreprocessor(AbstractPreprocessor): """ batch_size = inputs.shape[0] + input_size = inputs.shape[1:3] index_dims = inputs.shape[1:-1] indices = np.prod(index_dims) @@ -3156,7 +3203,7 @@ class PerceiverImagePreprocessor(AbstractPreprocessor): # Construct the position encoding. if self.position_encoding_type == "trainable": - pos_enc = self.position_embeddings(batch_size) + pos_enc = self.position_embeddings(batch_size, interpolate_pos_encoding, input_size) elif self.position_encoding_type == "fourier": pos_enc = self.position_embeddings(index_dims, batch_size, device=inputs.device, dtype=inputs.dtype) @@ -3174,7 +3221,13 @@ class PerceiverImagePreprocessor(AbstractPreprocessor): inputs_with_pos = inputs + pos_enc return inputs_with_pos, inputs - def forward(self, inputs: torch.Tensor, pos: Optional[torch.Tensor] = None, network_input_is_1d: bool = True): + def forward( + self, + inputs: torch.Tensor, + pos: Optional[torch.Tensor] = None, + network_input_is_1d: bool = True, + interpolate_pos_encoding: bool = False, + ): if self.prep_type == "conv": # Convnet image featurization. # Downsamples spatially by a factor of 4 @@ -3218,7 +3271,7 @@ class PerceiverImagePreprocessor(AbstractPreprocessor): else: raise ValueError("Unsupported data format for conv1x1.") - inputs, inputs_without_pos = self._build_network_inputs(inputs, network_input_is_1d) + inputs, inputs_without_pos = self._build_network_inputs(inputs, network_input_is_1d, interpolate_pos_encoding) modality_sizes = None # Size for each modality, only needed for multimodal return inputs, modality_sizes, inputs_without_pos @@ -3338,7 +3391,13 @@ class PerceiverAudioPreprocessor(AbstractPreprocessor): return inputs_with_pos, inputs - def forward(self, inputs: torch.Tensor, pos: Optional[torch.Tensor] = None, network_input_is_1d: bool = True): + def forward( + self, + inputs: torch.Tensor, + pos: Optional[torch.Tensor] = None, + network_input_is_1d: bool = True, + interpolate_pos_encoding: bool = False, + ): inputs = torch.reshape(inputs, [inputs.shape[0], -1, self.samples_per_patch]) inputs, inputs_without_pos = self._build_network_inputs(inputs) @@ -3391,7 +3450,11 @@ class PerceiverMultimodalPreprocessor(AbstractPreprocessor): return common_channel_size def forward( - self, inputs: Mapping[str, torch.Tensor], pos: Optional[torch.Tensor] = None, network_input_is_1d: bool = True + self, + inputs: Mapping[str, torch.Tensor], + pos: Optional[torch.Tensor] = None, + network_input_is_1d: bool = True, + interpolate_pos_encoding: bool = False, ) -> PreprocessorOutputType: padded = {} modality_sizes = {} diff --git a/tests/models/perceiver/test_modeling_perceiver.py b/tests/models/perceiver/test_modeling_perceiver.py index 530906388f..379b4774ca 100644 --- a/tests/models/perceiver/test_modeling_perceiver.py +++ b/tests/models/perceiver/test_modeling_perceiver.py @@ -1031,3 +1031,23 @@ class PerceiverModelIntegrationTest(unittest.TestCase): ) self.assertTrue(torch.allclose(logits[0, :3, :3, :3], expected_slice, atol=1e-4)) + + @slow + def test_inference_interpolate_pos_encoding(self): + image_processor = PerceiverImageProcessor(size={"height": 384, "width": 384}) + model = PerceiverForImageClassificationLearned.from_pretrained("deepmind/vision-perceiver-learned") + model.to(torch_device) + + # prepare inputs + image = prepare_img() + inputs = image_processor(image, return_tensors="pt").pixel_values.to(torch_device) + input_mask = None + + # forward pass + with torch.no_grad(): + outputs = model(inputs=inputs, attention_mask=input_mask, interpolate_pos_encoding=True) + logits = outputs.logits + + # verify logits + expected_shape = torch.Size((1, model.config.num_labels)) + self.assertEqual(logits.shape, expected_shape)