Fixed VitDet for non-squre Images (#35969)
* size tuple * delete original input_size * use zip * process the other case * Update src/transformers/models/vitdet/modeling_vitdet.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * [VITDET] Test non-square image * [Fix] Make Quality * make fix style * Update src/transformers/models/vitdet/modeling_vitdet.py --------- Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
This commit is contained in:
committed by
GitHub
parent
cbe0ea59f3
commit
9ebfda3263
@@ -290,6 +290,31 @@ class VitDetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
def test_model_from_pretrained(self):
|
||||
pass
|
||||
|
||||
def test_non_square_image(self):
|
||||
non_square_image_size = (32, 40)
|
||||
patch_size = (2, 2)
|
||||
config = self.model_tester.get_config()
|
||||
config.image_size = non_square_image_size
|
||||
config.patch_size = patch_size
|
||||
|
||||
model = VitDetModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
batch_size = self.model_tester.batch_size
|
||||
# Create a dummy input tensor with non-square spatial dimensions.
|
||||
pixel_values = floats_tensor(
|
||||
[batch_size, config.num_channels, non_square_image_size[0], non_square_image_size[1]]
|
||||
)
|
||||
|
||||
result = model(pixel_values)
|
||||
|
||||
expected_height = non_square_image_size[0] / patch_size[0]
|
||||
expected_width = non_square_image_size[1] / patch_size[1]
|
||||
expected_shape = (batch_size, config.hidden_size, expected_height, expected_width)
|
||||
|
||||
self.assertEqual(result.last_hidden_state.shape, expected_shape)
|
||||
|
||||
|
||||
@require_torch
|
||||
class VitDetBackboneTest(unittest.TestCase, BackboneTesterMixin):
|
||||
|
||||
Reference in New Issue
Block a user