Enable dynamic resolution input for Swin Transformer and variants (#30656)
* add interpolation of positional encoding support to swin * add style changes * use default image processor and make size a dictionary Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * remove logits testing Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Refactor image size validation logic when interpolation is disabled Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * remove asserts in modeling Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * add dynamic resolution input support to swinv2 * change size to ensure interpolation encoding path is triggered * set interpolate_pos_encoding default value to False Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * set interpolate_pos_encoding default value to False Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * set interpolate_pos_encoding default value to False Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * set interpolate_pos_encoding default value to False Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * set interpolate_pos_encoding default value to False Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * set interpolate_pos_encoding default value to False Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * set interpolate_pos_encoding default value to False Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * set interpolate_pos_encoding default value to False * add dynamic resolution input to donut swin * add dynamic resolution input to maskformer swin --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -493,6 +493,26 @@ class SwinModelIntegrationTest(unittest.TestCase):
|
||||
expected_slice = torch.tensor([-0.0948, -0.6454, -0.0921]).to(torch_device)
|
||||
self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
|
||||
|
||||
@slow
|
||||
def test_inference_interpolate_pos_encoding(self):
|
||||
# Swin models have an `interpolate_pos_encoding` argument in their forward method,
|
||||
# allowing to interpolate the pre-trained position embeddings in order to use
|
||||
# the model on higher resolutions.
|
||||
model = SwinModel.from_pretrained("microsoft/swin-tiny-patch4-window7-224").to(torch_device)
|
||||
|
||||
image_processor = self.default_image_processor
|
||||
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
||||
inputs = image_processor(images=image, size={"height": 481, "width": 481}, return_tensors="pt")
|
||||
pixel_values = inputs.pixel_values.to(torch_device)
|
||||
|
||||
# forward pass
|
||||
with torch.no_grad():
|
||||
outputs = model(pixel_values, interpolate_pos_encoding=True)
|
||||
|
||||
# verify the logits
|
||||
expected_shape = torch.Size((1, 256, 768))
|
||||
self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
|
||||
|
||||
|
||||
@require_torch
|
||||
class SwinBackboneTest(unittest.TestCase, BackboneTesterMixin):
|
||||
|
||||
@@ -485,6 +485,26 @@ class Swinv2ModelIntegrationTest(unittest.TestCase):
|
||||
expected_slice = torch.tensor([-0.3947, -0.4306, 0.0026]).to(torch_device)
|
||||
self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
|
||||
|
||||
@slow
|
||||
def test_inference_interpolate_pos_encoding(self):
|
||||
# Swinv2 models have an `interpolate_pos_encoding` argument in their forward method,
|
||||
# allowing to interpolate the pre-trained position embeddings in order to use
|
||||
# the model on higher resolutions.
|
||||
model = Swinv2Model.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256").to(torch_device)
|
||||
|
||||
image_processor = self.default_image_processor
|
||||
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
||||
inputs = image_processor(images=image, size={"height": 481, "width": 481}, return_tensors="pt")
|
||||
pixel_values = inputs.pixel_values.to(torch_device)
|
||||
|
||||
# forward pass
|
||||
with torch.no_grad():
|
||||
outputs = model(pixel_values, interpolate_pos_encoding=True)
|
||||
|
||||
# verify the logits
|
||||
expected_shape = torch.Size((1, 256, 768))
|
||||
self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
|
||||
|
||||
|
||||
@require_torch
|
||||
class Swinv2BackboneTest(unittest.TestCase, BackboneTesterMixin):
|
||||
|
||||
Reference in New Issue
Block a user