From 9ebfda3263fcdc4e05fd87fad1aadc8a08294608 Mon Sep 17 00:00:00 2001 From: "Chulhwa (Evan) Han" Date: Wed, 26 Feb 2025 04:31:24 +0900 Subject: [PATCH] Fixed VitDet for non-square Images (#35969) * size tuple * delete original input_size * use zip * process the other case * Update src/transformers/models/vitdet/modeling_vitdet.py Co-authored-by: Pavel Iakubovskii * [VITDET] Test non-square image * [Fix] Make Quality * make fix style * Update src/transformers/models/vitdet/modeling_vitdet.py --------- Co-authored-by: Pavel Iakubovskii --- .../models/vitdet/modeling_vitdet.py | 8 +++++- tests/models/vitdet/test_modeling_vitdet.py | 25 +++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/vitdet/modeling_vitdet.py b/src/transformers/models/vitdet/modeling_vitdet.py index 9bd7ca2ff1..9585c295e1 100644 --- a/src/transformers/models/vitdet/modeling_vitdet.py +++ b/src/transformers/models/vitdet/modeling_vitdet.py @@ -456,8 +456,14 @@ class VitDetLayer(nn.Module): super().__init__() dim = config.hidden_size - input_size = (config.image_size // config.patch_size, config.image_size // config.patch_size) + image_size = config.image_size + image_size = image_size if isinstance(image_size, (list, tuple)) else (image_size, image_size) + + patch_size = config.patch_size + patch_size = patch_size if isinstance(patch_size, (list, tuple)) else (patch_size, patch_size) + + input_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1]) self.norm1 = nn.LayerNorm(dim, eps=config.layer_norm_eps) self.attention = VitDetAttention( config, input_size=input_size if window_size == 0 else (window_size, window_size) diff --git a/tests/models/vitdet/test_modeling_vitdet.py b/tests/models/vitdet/test_modeling_vitdet.py index 2c46b60f7e..4b5ac0f337 100644 --- a/tests/models/vitdet/test_modeling_vitdet.py +++ b/tests/models/vitdet/test_modeling_vitdet.py @@ -290,6 +290,31 @@ class VitDetModelTest(ModelTesterMixin, 
PipelineTesterMixin, unittest.TestCase): def test_model_from_pretrained(self): pass + def test_non_square_image(self): + non_square_image_size = (32, 40) + patch_size = (2, 2) + config = self.model_tester.get_config() + config.image_size = non_square_image_size + config.patch_size = patch_size + + model = VitDetModel(config=config) + model.to(torch_device) + model.eval() + + batch_size = self.model_tester.batch_size + # Create a dummy input tensor with non-square spatial dimensions. + pixel_values = floats_tensor( + [batch_size, config.num_channels, non_square_image_size[0], non_square_image_size[1]] + ) + + result = model(pixel_values) + + expected_height = non_square_image_size[0] / patch_size[0] + expected_width = non_square_image_size[1] / patch_size[1] + expected_shape = (batch_size, config.hidden_size, expected_height, expected_width) + + self.assertEqual(result.last_hidden_state.shape, expected_shape) + @require_torch class VitDetBackboneTest(unittest.TestCase, BackboneTesterMixin):