fixed the batch_size error, all tests are passing Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
@@ -202,7 +202,7 @@ class FlaxDionv2ModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
||||
return image
|
||||
return [image, image]
|
||||
|
||||
|
||||
@require_vision
|
||||
@@ -224,18 +224,25 @@ class FlaxDinov2ModelIntegrationTest(unittest.TestCase):
|
||||
outputs = model(pixel_values=pixel_values)
|
||||
|
||||
# verify the logits
|
||||
expected_shape = (1, 257, 768)
|
||||
expected_shape = (2, 257, 768)
|
||||
self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
|
||||
|
||||
expected_slice = np.array(
|
||||
[
|
||||
[-2.1629121, -0.46566057, 1.0925977],
|
||||
[-3.5971704, -1.0283585, -1.1780515],
|
||||
[-2.900407, 1.1334689, -0.74357724],
|
||||
[
|
||||
[-2.1629121, -0.46566057, 1.0925977],
|
||||
[-3.5971704, -1.0283585, -1.1780515],
|
||||
[-2.900407, 1.1334689, -0.74357724],
|
||||
],
|
||||
[
|
||||
[-2.1629121, -0.46566057, 1.0925977],
|
||||
[-3.5971704, -1.0283585, -1.1780515],
|
||||
[-2.900407, 1.1334689, -0.74357724],
|
||||
],
|
||||
]
|
||||
)
|
||||
|
||||
self.assertTrue(np.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4))
|
||||
self.assertTrue(np.allclose(outputs.last_hidden_state[:2, :3, :3], expected_slice, atol=1e-4))
|
||||
|
||||
@slow
|
||||
def test_inference_image_classification_head_imagenet_1k(self):
|
||||
@@ -252,12 +259,13 @@ class FlaxDinov2ModelIntegrationTest(unittest.TestCase):
|
||||
logits = outputs.logits
|
||||
|
||||
# verify the logits
|
||||
expected_shape = (1, 1000)
|
||||
expected_shape = (2, 1000)
|
||||
self.assertEqual(logits.shape, expected_shape)
|
||||
|
||||
expected_slice = np.array([-2.1776447, 0.36716992, 0.13870952])
|
||||
expected_slice = np.array([[-2.1776447, 0.36716992, 0.13870952], [-2.1776447, 0.36716992, 0.13870952]])
|
||||
|
||||
self.assertTrue(np.allclose(logits[0, :3], expected_slice, atol=1e-4))
|
||||
self.assertTrue(np.allclose(logits[:2, :3], expected_slice, atol=1e-3))
|
||||
|
||||
expected_class_idx = 281
|
||||
self.assertEqual(logits.argmax(-1).item(), expected_class_idx)
|
||||
self.assertEqual(logits[0].argmax(-1).item(), expected_class_idx)
|
||||
self.assertEqual(logits[1].argmax(-1).item(), expected_class_idx)
|
||||
|
||||
Reference in New Issue
Block a user