Fixed VitDet for non-squre Images (#35969)

* size tuple

* delete original input_size

* use zip

* process the other case

* Update src/transformers/models/vitdet/modeling_vitdet.py

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* [VITDET] Test non-square image

* [Fix] Make Quality

* make fix style

* Update src/transformers/models/vitdet/modeling_vitdet.py

---------

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
This commit is contained in:
Chulhwa (Evan) Han
2025-02-26 04:31:24 +09:00
committed by GitHub
parent cbe0ea59f3
commit 9ebfda3263
2 changed files with 32 additions and 1 deletions

View File

@@ -456,8 +456,14 @@ class VitDetLayer(nn.Module):
super().__init__() super().__init__()
dim = config.hidden_size dim = config.hidden_size
input_size = (config.image_size // config.patch_size, config.image_size // config.patch_size)
image_size = config.image_size
image_size = image_size if isinstance(image_size, (list, tuple)) else (image_size, image_size)
patch_size = config.patch_size
patch_size = patch_size if isinstance(patch_size, (list, tuple)) else (patch_size, patch_size)
input_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
self.norm1 = nn.LayerNorm(dim, eps=config.layer_norm_eps) self.norm1 = nn.LayerNorm(dim, eps=config.layer_norm_eps)
self.attention = VitDetAttention( self.attention = VitDetAttention(
config, input_size=input_size if window_size == 0 else (window_size, window_size) config, input_size=input_size if window_size == 0 else (window_size, window_size)

View File

@@ -290,6 +290,31 @@ class VitDetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
pass pass
def test_non_square_image(self):
non_square_image_size = (32, 40)
patch_size = (2, 2)
config = self.model_tester.get_config()
config.image_size = non_square_image_size
config.patch_size = patch_size
model = VitDetModel(config=config)
model.to(torch_device)
model.eval()
batch_size = self.model_tester.batch_size
# Create a dummy input tensor with non-square spatial dimensions.
pixel_values = floats_tensor(
[batch_size, config.num_channels, non_square_image_size[0], non_square_image_size[1]]
)
result = model(pixel_values)
expected_height = non_square_image_size[0] / patch_size[0]
expected_width = non_square_image_size[1] / patch_size[1]
expected_shape = (batch_size, config.hidden_size, expected_height, expected_width)
self.assertEqual(result.last_hidden_state.shape, expected_shape)
@require_torch @require_torch
class VitDetBackboneTest(unittest.TestCase, BackboneTesterMixin): class VitDetBackboneTest(unittest.TestCase, BackboneTesterMixin):