Perceiver interpolate position embedding (#30979)
* add test that currently fails * test passed * all perceiver passed * fixup, style, quality, repo-consistency, all passed * Apply suggestions from code review: default to False + compute sqrt once only Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * fix a minor bracket * replace dim with self._num_channels * add arguments to the rest preprocessors --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -1031,3 +1031,23 @@ class PerceiverModelIntegrationTest(unittest.TestCase):
|
||||
)
|
||||
|
||||
self.assertTrue(torch.allclose(logits[0, :3, :3, :3], expected_slice, atol=1e-4))
|
||||
|
||||
@slow
|
||||
def test_inference_interpolate_pos_encoding(self):
|
||||
image_processor = PerceiverImageProcessor(size={"height": 384, "width": 384})
|
||||
model = PerceiverForImageClassificationLearned.from_pretrained("deepmind/vision-perceiver-learned")
|
||||
model.to(torch_device)
|
||||
|
||||
# prepare inputs
|
||||
image = prepare_img()
|
||||
inputs = image_processor(image, return_tensors="pt").pixel_values.to(torch_device)
|
||||
input_mask = None
|
||||
|
||||
# forward pass
|
||||
with torch.no_grad():
|
||||
outputs = model(inputs=inputs, attention_mask=input_mask, interpolate_pos_encoding=True)
|
||||
logits = outputs.logits
|
||||
|
||||
# verify logits
|
||||
expected_shape = torch.Size((1, model.config.num_labels))
|
||||
self.assertEqual(logits.shape, expected_shape)
|
||||
|
||||
Reference in New Issue
Block a user