Perceiver interpolate position embedding (#30979)
* add test that currently fails
* test passed
* all perceiver passed
* fixup, style, quality, repo-consistency, all passed
* Apply suggestions from code review: default to False + compute sqrt once only

  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* fix a minor bracket
* replace dim with self._num_channels
* add arguments to the rest preprocessors

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -699,13 +699,24 @@ PERCEIVER_INPUTS_DOCSTRING = r"""
|
|||||||
output_hidden_states (`bool`, *optional*):
|
output_hidden_states (`bool`, *optional*):
|
||||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||||
more detail.
|
more detail.
|
||||||
|
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to interpolate the pre-trained position encodings.
|
||||||
return_dict (`bool`, *optional*):
|
return_dict (`bool`, *optional*):
|
||||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
"""The Perceiver: a scalable, fully attentional architecture.""",
|
"""The Perceiver: a scalable, fully attentional architecture.
|
||||||
|
|
||||||
|
<Tip>
|
||||||
|
|
||||||
|
Note that it's possible to fine-tune Perceiver on higher resolution images than the ones it has been trained on, by
|
||||||
|
setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
|
||||||
|
position embeddings to the higher resolution.
|
||||||
|
|
||||||
|
</Tip>
|
||||||
|
""",
|
||||||
PERCEIVER_MODEL_START_DOCSTRING,
|
PERCEIVER_MODEL_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
class PerceiverModel(PerceiverPreTrainedModel):
|
class PerceiverModel(PerceiverPreTrainedModel):
|
||||||
@@ -754,6 +765,7 @@ class PerceiverModel(PerceiverPreTrainedModel):
|
|||||||
head_mask: Optional[torch.FloatTensor] = None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
interpolate_pos_encoding: bool = False,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
) -> Union[Tuple, PerceiverModelOutput]:
|
) -> Union[Tuple, PerceiverModelOutput]:
|
||||||
r"""
|
r"""
|
||||||
@@ -857,7 +869,9 @@ class PerceiverModel(PerceiverPreTrainedModel):
|
|||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
if self.input_preprocessor is not None:
|
if self.input_preprocessor is not None:
|
||||||
inputs, modality_sizes, inputs_without_pos = self.input_preprocessor(inputs)
|
inputs, modality_sizes, inputs_without_pos = self.input_preprocessor(
|
||||||
|
inputs, interpolate_pos_encoding=interpolate_pos_encoding
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
modality_sizes = None
|
modality_sizes = None
|
||||||
inputs_without_pos = None
|
inputs_without_pos = None
|
||||||
@@ -1247,6 +1261,7 @@ class PerceiverForImageClassificationLearned(PerceiverPreTrainedModel):
|
|||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
labels: Optional[torch.Tensor] = None,
|
labels: Optional[torch.Tensor] = None,
|
||||||
|
interpolate_pos_encoding: bool = False,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
pixel_values: Optional[torch.Tensor] = None,
|
pixel_values: Optional[torch.Tensor] = None,
|
||||||
) -> Union[Tuple, PerceiverClassifierOutput]:
|
) -> Union[Tuple, PerceiverClassifierOutput]:
|
||||||
@@ -1295,6 +1310,7 @@ class PerceiverForImageClassificationLearned(PerceiverPreTrainedModel):
|
|||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
|
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
)
|
)
|
||||||
logits = outputs.logits if return_dict else outputs[0]
|
logits = outputs.logits if return_dict else outputs[0]
|
||||||
@@ -2749,9 +2765,31 @@ class PerceiverTrainablePositionEncoding(PerceiverAbstractPositionEncoding):
|
|||||||
def output_size(self, *args, **kwargs) -> int:
|
def output_size(self, *args, **kwargs) -> int:
|
||||||
return self._num_channels
|
return self._num_channels
|
||||||
|
|
||||||
def forward(self, batch_size: int) -> torch.Tensor:
|
def interpolate_pos_encoding(self, position_embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||||
|
num_positions = position_embeddings.shape[0]
|
||||||
|
new_height = new_width = math.sqrt(num_positions)
|
||||||
|
position_embeddings = position_embeddings.reshape(
|
||||||
|
1, int(new_height), int(new_width), self._num_channels
|
||||||
|
).permute(0, 3, 1, 2)
|
||||||
|
position_embeddings = nn.functional.interpolate(
|
||||||
|
position_embeddings,
|
||||||
|
scale_factor=(height / new_height, width / new_width),
|
||||||
|
mode="bicubic",
|
||||||
|
align_corners=False,
|
||||||
|
)
|
||||||
|
position_embeddings = position_embeddings.reshape(1, self._num_channels, -1).permute(0, 2, 1).squeeze(0)
|
||||||
|
return position_embeddings
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, batch_size: int, interpolate_pos_encoding: bool = False, input_size: torch.Size = None
|
||||||
|
) -> torch.Tensor:
|
||||||
position_embeddings = self.position_embeddings
|
position_embeddings = self.position_embeddings
|
||||||
|
|
||||||
|
if interpolate_pos_encoding:
|
||||||
|
height, width = input_size
|
||||||
|
height, width = height + 0.1, width + 0.1
|
||||||
|
position_embeddings = self.interpolate_pos_encoding(position_embeddings, height, width)
|
||||||
|
|
||||||
if batch_size is not None:
|
if batch_size is not None:
|
||||||
position_embeddings = position_embeddings.expand(batch_size, -1, -1)
|
position_embeddings = position_embeddings.expand(batch_size, -1, -1)
|
||||||
return position_embeddings
|
return position_embeddings
|
||||||
@@ -2859,7 +2897,13 @@ class PerceiverTextPreprocessor(AbstractPreprocessor):
|
|||||||
def num_channels(self) -> int:
|
def num_channels(self) -> int:
|
||||||
return self.config.d_model
|
return self.config.d_model
|
||||||
|
|
||||||
def forward(self, inputs: torch.LongTensor, pos: Optional[torch.Tensor] = None, network_input_is_1d: bool = True):
|
def forward(
|
||||||
|
self,
|
||||||
|
inputs: torch.LongTensor,
|
||||||
|
pos: Optional[torch.Tensor] = None,
|
||||||
|
network_input_is_1d: bool = True,
|
||||||
|
interpolate_pos_encoding: bool = False,
|
||||||
|
):
|
||||||
embeddings_without_pos = self.embeddings(inputs)
|
embeddings_without_pos = self.embeddings(inputs)
|
||||||
|
|
||||||
seq_length = inputs.shape[1]
|
seq_length = inputs.shape[1]
|
||||||
@@ -3139,7 +3183,9 @@ class PerceiverImagePreprocessor(AbstractPreprocessor):
|
|||||||
|
|
||||||
return inp_dim + pos_dim
|
return inp_dim + pos_dim
|
||||||
|
|
||||||
def _build_network_inputs(self, inputs: torch.Tensor, network_input_is_1d: bool = True):
|
def _build_network_inputs(
|
||||||
|
self, inputs: torch.Tensor, network_input_is_1d: bool = True, interpolate_pos_encoding: bool = False
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Construct the final input, including position encoding.
|
Construct the final input, including position encoding.
|
||||||
|
|
||||||
@@ -3147,6 +3193,7 @@ class PerceiverImagePreprocessor(AbstractPreprocessor):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
batch_size = inputs.shape[0]
|
batch_size = inputs.shape[0]
|
||||||
|
input_size = inputs.shape[1:3]
|
||||||
index_dims = inputs.shape[1:-1]
|
index_dims = inputs.shape[1:-1]
|
||||||
indices = np.prod(index_dims)
|
indices = np.prod(index_dims)
|
||||||
|
|
||||||
@@ -3156,7 +3203,7 @@ class PerceiverImagePreprocessor(AbstractPreprocessor):
|
|||||||
|
|
||||||
# Construct the position encoding.
|
# Construct the position encoding.
|
||||||
if self.position_encoding_type == "trainable":
|
if self.position_encoding_type == "trainable":
|
||||||
pos_enc = self.position_embeddings(batch_size)
|
pos_enc = self.position_embeddings(batch_size, interpolate_pos_encoding, input_size)
|
||||||
elif self.position_encoding_type == "fourier":
|
elif self.position_encoding_type == "fourier":
|
||||||
pos_enc = self.position_embeddings(index_dims, batch_size, device=inputs.device, dtype=inputs.dtype)
|
pos_enc = self.position_embeddings(index_dims, batch_size, device=inputs.device, dtype=inputs.dtype)
|
||||||
|
|
||||||
@@ -3174,7 +3221,13 @@ class PerceiverImagePreprocessor(AbstractPreprocessor):
|
|||||||
inputs_with_pos = inputs + pos_enc
|
inputs_with_pos = inputs + pos_enc
|
||||||
return inputs_with_pos, inputs
|
return inputs_with_pos, inputs
|
||||||
|
|
||||||
def forward(self, inputs: torch.Tensor, pos: Optional[torch.Tensor] = None, network_input_is_1d: bool = True):
|
def forward(
|
||||||
|
self,
|
||||||
|
inputs: torch.Tensor,
|
||||||
|
pos: Optional[torch.Tensor] = None,
|
||||||
|
network_input_is_1d: bool = True,
|
||||||
|
interpolate_pos_encoding: bool = False,
|
||||||
|
):
|
||||||
if self.prep_type == "conv":
|
if self.prep_type == "conv":
|
||||||
# Convnet image featurization.
|
# Convnet image featurization.
|
||||||
# Downsamples spatially by a factor of 4
|
# Downsamples spatially by a factor of 4
|
||||||
@@ -3218,7 +3271,7 @@ class PerceiverImagePreprocessor(AbstractPreprocessor):
|
|||||||
else:
|
else:
|
||||||
raise ValueError("Unsupported data format for conv1x1.")
|
raise ValueError("Unsupported data format for conv1x1.")
|
||||||
|
|
||||||
inputs, inputs_without_pos = self._build_network_inputs(inputs, network_input_is_1d)
|
inputs, inputs_without_pos = self._build_network_inputs(inputs, network_input_is_1d, interpolate_pos_encoding)
|
||||||
modality_sizes = None # Size for each modality, only needed for multimodal
|
modality_sizes = None # Size for each modality, only needed for multimodal
|
||||||
|
|
||||||
return inputs, modality_sizes, inputs_without_pos
|
return inputs, modality_sizes, inputs_without_pos
|
||||||
@@ -3338,7 +3391,13 @@ class PerceiverAudioPreprocessor(AbstractPreprocessor):
|
|||||||
|
|
||||||
return inputs_with_pos, inputs
|
return inputs_with_pos, inputs
|
||||||
|
|
||||||
def forward(self, inputs: torch.Tensor, pos: Optional[torch.Tensor] = None, network_input_is_1d: bool = True):
|
def forward(
|
||||||
|
self,
|
||||||
|
inputs: torch.Tensor,
|
||||||
|
pos: Optional[torch.Tensor] = None,
|
||||||
|
network_input_is_1d: bool = True,
|
||||||
|
interpolate_pos_encoding: bool = False,
|
||||||
|
):
|
||||||
inputs = torch.reshape(inputs, [inputs.shape[0], -1, self.samples_per_patch])
|
inputs = torch.reshape(inputs, [inputs.shape[0], -1, self.samples_per_patch])
|
||||||
|
|
||||||
inputs, inputs_without_pos = self._build_network_inputs(inputs)
|
inputs, inputs_without_pos = self._build_network_inputs(inputs)
|
||||||
@@ -3391,7 +3450,11 @@ class PerceiverMultimodalPreprocessor(AbstractPreprocessor):
|
|||||||
return common_channel_size
|
return common_channel_size
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, inputs: Mapping[str, torch.Tensor], pos: Optional[torch.Tensor] = None, network_input_is_1d: bool = True
|
self,
|
||||||
|
inputs: Mapping[str, torch.Tensor],
|
||||||
|
pos: Optional[torch.Tensor] = None,
|
||||||
|
network_input_is_1d: bool = True,
|
||||||
|
interpolate_pos_encoding: bool = False,
|
||||||
) -> PreprocessorOutputType:
|
) -> PreprocessorOutputType:
|
||||||
padded = {}
|
padded = {}
|
||||||
modality_sizes = {}
|
modality_sizes = {}
|
||||||
|
|||||||
@@ -1031,3 +1031,23 @@ class PerceiverModelIntegrationTest(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.assertTrue(torch.allclose(logits[0, :3, :3, :3], expected_slice, atol=1e-4))
|
self.assertTrue(torch.allclose(logits[0, :3, :3, :3], expected_slice, atol=1e-4))
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_inference_interpolate_pos_encoding(self):
|
||||||
|
image_processor = PerceiverImageProcessor(size={"height": 384, "width": 384})
|
||||||
|
model = PerceiverForImageClassificationLearned.from_pretrained("deepmind/vision-perceiver-learned")
|
||||||
|
model.to(torch_device)
|
||||||
|
|
||||||
|
# prepare inputs
|
||||||
|
image = prepare_img()
|
||||||
|
inputs = image_processor(image, return_tensors="pt").pixel_values.to(torch_device)
|
||||||
|
input_mask = None
|
||||||
|
|
||||||
|
# forward pass
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(inputs=inputs, attention_mask=input_mask, interpolate_pos_encoding=True)
|
||||||
|
logits = outputs.logits
|
||||||
|
|
||||||
|
# verify logits
|
||||||
|
expected_shape = torch.Size((1, model.config.num_labels))
|
||||||
|
self.assertEqual(logits.shape, expected_shape)
|
||||||
|
|||||||
Reference in New Issue
Block a user