diff --git a/src/transformers/models/dinov2/modeling_flax_dinov2.py b/src/transformers/models/dinov2/modeling_flax_dinov2.py index cf2a6e04c4..8093b3a0b7 100644 --- a/src/transformers/models/dinov2/modeling_flax_dinov2.py +++ b/src/transformers/models/dinov2/modeling_flax_dinov2.py @@ -185,9 +185,11 @@ class FlaxDinov2Embeddings(nn.Module): antialias=False, ) patch_pos_embed = patch_pos_embed.astype(target_dtype) - patch_pos_embed = jnp.transpose(patch_pos_embed, (0, 2, 3, 1)).reshape((hidden_states.shape[0], -1, dim)) + patch_pos_embed = jnp.transpose(patch_pos_embed, (0, 2, 3, 1)).reshape((position_embeddings.shape[0], -1, dim)) + patch_pos_embed_expanded = jnp.tile(patch_pos_embed, (hidden_states.shape[0], 1, 1)) + class_pos_embed_expanded = jnp.tile(class_pos_embed, (hidden_states.shape[0], 1, 1)) - return jnp.concatenate((class_pos_embed[jnp.newaxis, :], patch_pos_embed), axis=1) + return jnp.concatenate((class_pos_embed_expanded, patch_pos_embed_expanded), axis=1) def __call__(self, pixel_values, deterministic=True): batch_size = pixel_values.shape[0] @@ -778,7 +780,7 @@ FLAX_VISION_CLASSIFICATION_DOCSTRING = """ >>> image = Image.open(requests.get(url, stream=True).raw) >>> image_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base-imagenet1k-1-layer") - >>> model = FlaxDinov2ForImageClassification.from_pretrained("facebook/dinov2-base-imagenet1k-1-layer") + >>> model = FlaxDinov2ForImageClassification.from_pretrained("facebook/dinov2-base-imagenet1k-1-layer", from_pt=True) >>> inputs = image_processor(images=image, return_tensors="np") >>> outputs = model(**inputs) diff --git a/tests/models/dinov2/test_modeling_flax_dinov2.py b/tests/models/dinov2/test_modeling_flax_dinov2.py index 68510bb505..09ce20611a 100644 --- a/tests/models/dinov2/test_modeling_flax_dinov2.py +++ b/tests/models/dinov2/test_modeling_flax_dinov2.py @@ -202,7 +202,7 @@ class FlaxDionv2ModelTest(FlaxModelTesterMixin, unittest.TestCase): # We will verify our results on an image of cute cats def prepare_img(): image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") - return image + return [image, image] @require_vision @@ -224,18 +224,25 @@ class FlaxDinov2ModelIntegrationTest(unittest.TestCase): outputs = model(pixel_values=pixel_values) # verify the logits - expected_shape = (1, 257, 768) + expected_shape = (2, 257, 768) self.assertEqual(outputs.last_hidden_state.shape, expected_shape) expected_slice = np.array( [ - [-2.1629121, -0.46566057, 1.0925977], - [-3.5971704, -1.0283585, -1.1780515], - [-2.900407, 1.1334689, -0.74357724], + [ + [-2.1629121, -0.46566057, 1.0925977], + [-3.5971704, -1.0283585, -1.1780515], + [-2.900407, 1.1334689, -0.74357724], + ], + [ + [-2.1629121, -0.46566057, 1.0925977], + [-3.5971704, -1.0283585, -1.1780515], + [-2.900407, 1.1334689, -0.74357724], + ], ] ) - self.assertTrue(np.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4)) + self.assertTrue(np.allclose(outputs.last_hidden_state[:2, :3, :3], expected_slice, atol=1e-4)) @slow def test_inference_image_classification_head_imagenet_1k(self): @@ -252,12 +259,13 @@ class FlaxDinov2ModelIntegrationTest(unittest.TestCase): logits = outputs.logits # verify the logits - expected_shape = (1, 1000) + expected_shape = (2, 1000) self.assertEqual(logits.shape, expected_shape) - expected_slice = np.array([-2.1776447, 0.36716992, 0.13870952]) + expected_slice = np.array([[-2.1776447, 0.36716992, 0.13870952], [-2.1776447, 0.36716992, 0.13870952]]) - self.assertTrue(np.allclose(logits[0, :3], expected_slice, atol=1e-4)) + self.assertTrue(np.allclose(logits[:2, :3], expected_slice, atol=1e-3)) expected_class_idx = 281 - self.assertEqual(logits.argmax(-1).item(), expected_class_idx) + self.assertEqual(logits[0].argmax(-1).item(), expected_class_idx) + self.assertEqual(logits[1].argmax(-1).item(), expected_class_idx)