Fixed VitDet for non-squre Images (#35969)
* size tuple * delete original input_size * use zip * process the other case * Update src/transformers/models/vitdet/modeling_vitdet.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * [VITDET] Test non-square image * [Fix] Make Quality * make fix style * Update src/transformers/models/vitdet/modeling_vitdet.py --------- Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
This commit is contained in:
committed by
GitHub
parent
cbe0ea59f3
commit
9ebfda3263
@@ -456,8 +456,14 @@ class VitDetLayer(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
dim = config.hidden_size
|
dim = config.hidden_size
|
||||||
input_size = (config.image_size // config.patch_size, config.image_size // config.patch_size)
|
|
||||||
|
|
||||||
|
image_size = config.image_size
|
||||||
|
image_size = image_size if isinstance(image_size, (list, tuple)) else (image_size, image_size)
|
||||||
|
|
||||||
|
patch_size = config.patch_size
|
||||||
|
patch_size = patch_size if isinstance(patch_size, (list, tuple)) else (patch_size, patch_size)
|
||||||
|
|
||||||
|
input_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
|
||||||
self.norm1 = nn.LayerNorm(dim, eps=config.layer_norm_eps)
|
self.norm1 = nn.LayerNorm(dim, eps=config.layer_norm_eps)
|
||||||
self.attention = VitDetAttention(
|
self.attention = VitDetAttention(
|
||||||
config, input_size=input_size if window_size == 0 else (window_size, window_size)
|
config, input_size=input_size if window_size == 0 else (window_size, window_size)
|
||||||
|
|||||||
@@ -290,6 +290,31 @@ class VitDetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def test_non_square_image(self):
|
||||||
|
non_square_image_size = (32, 40)
|
||||||
|
patch_size = (2, 2)
|
||||||
|
config = self.model_tester.get_config()
|
||||||
|
config.image_size = non_square_image_size
|
||||||
|
config.patch_size = patch_size
|
||||||
|
|
||||||
|
model = VitDetModel(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
batch_size = self.model_tester.batch_size
|
||||||
|
# Create a dummy input tensor with non-square spatial dimensions.
|
||||||
|
pixel_values = floats_tensor(
|
||||||
|
[batch_size, config.num_channels, non_square_image_size[0], non_square_image_size[1]]
|
||||||
|
)
|
||||||
|
|
||||||
|
result = model(pixel_values)
|
||||||
|
|
||||||
|
expected_height = non_square_image_size[0] / patch_size[0]
|
||||||
|
expected_width = non_square_image_size[1] / patch_size[1]
|
||||||
|
expected_shape = (batch_size, config.hidden_size, expected_height, expected_width)
|
||||||
|
|
||||||
|
self.assertEqual(result.last_hidden_state.shape, expected_shape)
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class VitDetBackboneTest(unittest.TestCase, BackboneTesterMixin):
|
class VitDetBackboneTest(unittest.TestCase, BackboneTesterMixin):
|
||||||
|
|||||||
Reference in New Issue
Block a user