Enable dynamic resolution for vivit (#30630)
* feat: enable dynamic resolution for vivit * fix: formatting * remove: print statement for testing * Update src/transformers/models/vivit/modeling_vivit.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/vivit/modeling_vivit.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/vivit/modeling_vivit.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update tests/models/vivit/test_modeling_vivit.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update tests/models/vivit/test_modeling_vivit.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/vivit/modeling_vivit.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update tests/models/vivit/test_modeling_vivit.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/vivit/modeling_vivit.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/vivit/modeling_vivit.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/vivit/modeling_vivit.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/vivit/modeling_vivit.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * fix: style check --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -353,3 +353,26 @@ class VivitModelIntegrationTest(unittest.TestCase):
|
||||
expected_slice = torch.tensor([-0.9498, 2.7971, -1.4049, 0.1024, -1.8353]).to(torch_device)
|
||||
|
||||
self.assertTrue(torch.allclose(outputs.logits[0, :5], expected_slice, atol=1e-4))
|
||||
|
||||
@slow
|
||||
def test_inference_interpolate_pos_encoding(self):
|
||||
# Vivit models have an `interpolate_pos_encoding` argument in their forward method,
|
||||
# allowing to interpolate the pre-trained position embeddings in order to use
|
||||
# the model on higher resolutions. The DINO model by Facebook AI leverages this
|
||||
# to visualize self-attention on higher resolution images.
|
||||
model = VivitModel.from_pretrained("google/vivit-b-16x2").to(torch_device)
|
||||
|
||||
image_processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2")
|
||||
video = prepare_video()
|
||||
inputs = image_processor(
|
||||
video, size={"shortest_edge": 480}, crop_size={"height": 480, "width": 480}, return_tensors="pt"
|
||||
)
|
||||
pixel_values = inputs.pixel_values.to(torch_device)
|
||||
|
||||
# forward pass
|
||||
with torch.no_grad():
|
||||
outputs = model(pixel_values, interpolate_pos_encoding=True)
|
||||
|
||||
# verify the logits shape
|
||||
expected_shape = torch.Size((1, 3137, 768))
|
||||
self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
|
||||
|
||||
Reference in New Issue
Block a user