addressing the issue #34611 to make FlaxDinov2 compatible with any batch size (#35138)

fixed the batch_size error, all tests are passing

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
MAHIR DAIYAN
2025-02-25 18:44:44 +08:00
committed by GitHub
parent 3a02fe56c2
commit d80d52b007
2 changed files with 23 additions and 13 deletions

View File

@@ -202,7 +202,7 @@ class FlaxDionv2ModelTest(FlaxModelTesterMixin, unittest.TestCase):
# We will verify our results on an image of cute cats
def prepare_img():
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
return image
return [image, image]
@require_vision
@@ -224,18 +224,25 @@ class FlaxDinov2ModelIntegrationTest(unittest.TestCase):
outputs = model(pixel_values=pixel_values)
# verify the logits
expected_shape = (1, 257, 768)
expected_shape = (2, 257, 768)
self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
expected_slice = np.array(
[
[-2.1629121, -0.46566057, 1.0925977],
[-3.5971704, -1.0283585, -1.1780515],
[-2.900407, 1.1334689, -0.74357724],
[
[-2.1629121, -0.46566057, 1.0925977],
[-3.5971704, -1.0283585, -1.1780515],
[-2.900407, 1.1334689, -0.74357724],
],
[
[-2.1629121, -0.46566057, 1.0925977],
[-3.5971704, -1.0283585, -1.1780515],
[-2.900407, 1.1334689, -0.74357724],
],
]
)
self.assertTrue(np.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4))
self.assertTrue(np.allclose(outputs.last_hidden_state[:2, :3, :3], expected_slice, atol=1e-4))
@slow
def test_inference_image_classification_head_imagenet_1k(self):
@@ -252,12 +259,13 @@ class FlaxDinov2ModelIntegrationTest(unittest.TestCase):
logits = outputs.logits
# verify the logits
expected_shape = (1, 1000)
expected_shape = (2, 1000)
self.assertEqual(logits.shape, expected_shape)
expected_slice = np.array([-2.1776447, 0.36716992, 0.13870952])
expected_slice = np.array([[-2.1776447, 0.36716992, 0.13870952], [-2.1776447, 0.36716992, 0.13870952]])
self.assertTrue(np.allclose(logits[0, :3], expected_slice, atol=1e-4))
self.assertTrue(np.allclose(logits[:2, :3], expected_slice, atol=1e-3))
expected_class_idx = 281
self.assertEqual(logits.argmax(-1).item(), expected_class_idx)
self.assertEqual(logits[0].argmax(-1).item(), expected_class_idx)
self.assertEqual(logits[1].argmax(-1).item(), expected_class_idx)