Fixed the batch_size error; all tests are passing. Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
@@ -185,9 +185,11 @@ class FlaxDinov2Embeddings(nn.Module):
|
||||
antialias=False,
|
||||
)
|
||||
patch_pos_embed = patch_pos_embed.astype(target_dtype)
|
||||
patch_pos_embed = jnp.transpose(patch_pos_embed, (0, 2, 3, 1)).reshape((hidden_states.shape[0], -1, dim))
|
||||
patch_pos_embed = jnp.transpose(patch_pos_embed, (0, 2, 3, 1)).reshape((position_embeddings.shape[0], -1, dim))
|
||||
patch_pos_embed_expanded = jnp.tile(patch_pos_embed, (hidden_states.shape[0], 1, 1))
|
||||
class_pos_embed_expanded = jnp.tile(class_pos_embed, (hidden_states.shape[0], 1, 1))
|
||||
|
||||
return jnp.concatenate((class_pos_embed[jnp.newaxis, :], patch_pos_embed), axis=1)
|
||||
return jnp.concatenate((class_pos_embed_expanded, patch_pos_embed_expanded), axis=1)
|
||||
|
||||
def __call__(self, pixel_values, deterministic=True):
|
||||
batch_size = pixel_values.shape[0]
|
||||
@@ -778,7 +780,7 @@ FLAX_VISION_CLASSIFICATION_DOCSTRING = """
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> image_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base-imagenet1k-1-layer")
|
||||
>>> model = FlaxDinov2ForImageClassification.from_pretrained("facebook/dinov2-base-imagenet1k-1-layer")
|
||||
>>> model = FlaxDinov2ForImageClassification.from_pretrained("facebook/dinov2-base-imagenet1k-1-layer", from_pt=True)
|
||||
|
||||
>>> inputs = image_processor(images=image, return_tensors="np")
|
||||
>>> outputs = model(**inputs)
|
||||
|
||||
@@ -202,7 +202,7 @@ class FlaxDionv2ModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
||||
return image
|
||||
return [image, image]
|
||||
|
||||
|
||||
@require_vision
|
||||
@@ -224,18 +224,25 @@ class FlaxDinov2ModelIntegrationTest(unittest.TestCase):
|
||||
outputs = model(pixel_values=pixel_values)
|
||||
|
||||
# verify the logits
|
||||
expected_shape = (1, 257, 768)
|
||||
expected_shape = (2, 257, 768)
|
||||
self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
|
||||
|
||||
expected_slice = np.array(
|
||||
[
|
||||
[-2.1629121, -0.46566057, 1.0925977],
|
||||
[-3.5971704, -1.0283585, -1.1780515],
|
||||
[-2.900407, 1.1334689, -0.74357724],
|
||||
[
|
||||
[-2.1629121, -0.46566057, 1.0925977],
|
||||
[-3.5971704, -1.0283585, -1.1780515],
|
||||
[-2.900407, 1.1334689, -0.74357724],
|
||||
],
|
||||
[
|
||||
[-2.1629121, -0.46566057, 1.0925977],
|
||||
[-3.5971704, -1.0283585, -1.1780515],
|
||||
[-2.900407, 1.1334689, -0.74357724],
|
||||
],
|
||||
]
|
||||
)
|
||||
|
||||
self.assertTrue(np.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4))
|
||||
self.assertTrue(np.allclose(outputs.last_hidden_state[:2, :3, :3], expected_slice, atol=1e-4))
|
||||
|
||||
@slow
|
||||
def test_inference_image_classification_head_imagenet_1k(self):
|
||||
@@ -252,12 +259,13 @@ class FlaxDinov2ModelIntegrationTest(unittest.TestCase):
|
||||
logits = outputs.logits
|
||||
|
||||
# verify the logits
|
||||
expected_shape = (1, 1000)
|
||||
expected_shape = (2, 1000)
|
||||
self.assertEqual(logits.shape, expected_shape)
|
||||
|
||||
expected_slice = np.array([-2.1776447, 0.36716992, 0.13870952])
|
||||
expected_slice = np.array([[-2.1776447, 0.36716992, 0.13870952], [-2.1776447, 0.36716992, 0.13870952]])
|
||||
|
||||
self.assertTrue(np.allclose(logits[0, :3], expected_slice, atol=1e-4))
|
||||
self.assertTrue(np.allclose(logits[:2, :3], expected_slice, atol=1e-3))
|
||||
|
||||
expected_class_idx = 281
|
||||
self.assertEqual(logits.argmax(-1).item(), expected_class_idx)
|
||||
self.assertEqual(logits[0].argmax(-1).item(), expected_class_idx)
|
||||
self.assertEqual(logits[1].argmax(-1).item(), expected_class_idx)
|
||||
|
||||
Reference in New Issue
Block a user