Perceiver interpolate position embedding (#30979)

* add test that currently fails

* test passed

* all perceiver passed

* fixup, style, quality, repo-consistency, all passed

* Apply suggestions from code review: default to False + compute sqrt once only

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* fix a minor bracket

* replace dim with self._num_channels

* add arguments to the rest of the preprocessors

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Yixiang Gao
2024-05-24 05:13:58 -05:00
committed by GitHub
parent 5855afd1f3
commit 42d8dd8716
2 changed files with 93 additions and 10 deletions

View File

@@ -1031,3 +1031,23 @@ class PerceiverModelIntegrationTest(unittest.TestCase):
)
self.assertTrue(torch.allclose(logits[0, :3, :3, :3], expected_slice, atol=1e-4))
@slow
def test_inference_interpolate_pos_encoding(self):
    """Check the classifier runs with ``interpolate_pos_encoding=True``.

    A 384x384 image is pushed through the pretrained
    ``deepmind/vision-perceiver-learned`` checkpoint and only the shape of
    the returned logits is verified: ``(1, num_labels)``.
    NOTE(review): 384x384 is presumably different from the checkpoint's
    native input resolution, which is what makes interpolation necessary;
    confirm against the model config.
    """
    # Build the processor and the model, moving the model to the test device.
    processor = PerceiverImageProcessor(size={"height": 384, "width": 384})
    model = PerceiverForImageClassificationLearned.from_pretrained("deepmind/vision-perceiver-learned").to(
        torch_device
    )

    # Turn the fixture image into a pixel tensor on the same device.
    encoded = processor(prepare_img(), return_tensors="pt")
    pixel_values = encoded.pixel_values.to(torch_device)

    # Inference only: no gradients, no attention mask.
    with torch.no_grad():
        result = model(inputs=pixel_values, attention_mask=None, interpolate_pos_encoding=True)

    # One row of class scores for the single input image.
    self.assertEqual(result.logits.shape, torch.Size((1, model.config.num_labels)))