addressing the issue #34611 to make FlaxDinov2 compatible with any batch size (#35138)

fixed the batch_size error, all tests are passing

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
MAHIR DAIYAN
2025-02-25 18:44:44 +08:00
committed by GitHub
parent 3a02fe56c2
commit d80d52b007
2 changed files with 23 additions and 13 deletions

View File

@@ -185,9 +185,11 @@ class FlaxDinov2Embeddings(nn.Module):
antialias=False,
)
patch_pos_embed = patch_pos_embed.astype(target_dtype)
patch_pos_embed = jnp.transpose(patch_pos_embed, (0, 2, 3, 1)).reshape((hidden_states.shape[0], -1, dim))
patch_pos_embed = jnp.transpose(patch_pos_embed, (0, 2, 3, 1)).reshape((position_embeddings.shape[0], -1, dim))
patch_pos_embed_expanded = jnp.tile(patch_pos_embed, (hidden_states.shape[0], 1, 1))
class_pos_embed_expanded = jnp.tile(class_pos_embed, (hidden_states.shape[0], 1, 1))
return jnp.concatenate((class_pos_embed[jnp.newaxis, :], patch_pos_embed), axis=1)
return jnp.concatenate((class_pos_embed_expanded, patch_pos_embed_expanded), axis=1)
def __call__(self, pixel_values, deterministic=True):
batch_size = pixel_values.shape[0]
@@ -778,7 +780,7 @@ FLAX_VISION_CLASSIFICATION_DOCSTRING = """
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> image_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base-imagenet1k-1-layer")
>>> model = FlaxDinov2ForImageClassification.from_pretrained("facebook/dinov2-base-imagenet1k-1-layer")
>>> model = FlaxDinov2ForImageClassification.from_pretrained("facebook/dinov2-base-imagenet1k-1-layer", from_pt=True)
>>> inputs = image_processor(images=image, return_tensors="np")
>>> outputs = model(**inputs)

View File

@@ -202,7 +202,7 @@ class FlaxDionv2ModelTest(FlaxModelTesterMixin, unittest.TestCase):
# We will verify our results on an image of cute cats
def prepare_img():
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
return image
return [image, image]
@require_vision
@@ -224,18 +224,25 @@ class FlaxDinov2ModelIntegrationTest(unittest.TestCase):
outputs = model(pixel_values=pixel_values)
# verify the logits
expected_shape = (1, 257, 768)
expected_shape = (2, 257, 768)
self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
expected_slice = np.array(
[
[
[-2.1629121, -0.46566057, 1.0925977],
[-3.5971704, -1.0283585, -1.1780515],
[-2.900407, 1.1334689, -0.74357724],
],
[
[-2.1629121, -0.46566057, 1.0925977],
[-3.5971704, -1.0283585, -1.1780515],
[-2.900407, 1.1334689, -0.74357724],
],
]
)
self.assertTrue(np.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4))
self.assertTrue(np.allclose(outputs.last_hidden_state[:2, :3, :3], expected_slice, atol=1e-4))
@slow
def test_inference_image_classification_head_imagenet_1k(self):
@@ -252,12 +259,13 @@ class FlaxDinov2ModelIntegrationTest(unittest.TestCase):
logits = outputs.logits
# verify the logits
expected_shape = (1, 1000)
expected_shape = (2, 1000)
self.assertEqual(logits.shape, expected_shape)
expected_slice = np.array([-2.1776447, 0.36716992, 0.13870952])
expected_slice = np.array([[-2.1776447, 0.36716992, 0.13870952], [-2.1776447, 0.36716992, 0.13870952]])
self.assertTrue(np.allclose(logits[0, :3], expected_slice, atol=1e-4))
self.assertTrue(np.allclose(logits[:2, :3], expected_slice, atol=1e-3))
expected_class_idx = 281
self.assertEqual(logits.argmax(-1).item(), expected_class_idx)
self.assertEqual(logits[0].argmax(-1).item(), expected_class_idx)
self.assertEqual(logits[1].argmax(-1).item(), expected_class_idx)