fixed the batch_size error, all tests are passing Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
@@ -185,9 +185,11 @@ class FlaxDinov2Embeddings(nn.Module):
|
|||||||
antialias=False,
|
antialias=False,
|
||||||
)
|
)
|
||||||
patch_pos_embed = patch_pos_embed.astype(target_dtype)
|
patch_pos_embed = patch_pos_embed.astype(target_dtype)
|
||||||
patch_pos_embed = jnp.transpose(patch_pos_embed, (0, 2, 3, 1)).reshape((hidden_states.shape[0], -1, dim))
|
patch_pos_embed = jnp.transpose(patch_pos_embed, (0, 2, 3, 1)).reshape((position_embeddings.shape[0], -1, dim))
|
||||||
|
patch_pos_embed_expanded = jnp.tile(patch_pos_embed, (hidden_states.shape[0], 1, 1))
|
||||||
|
class_pos_embed_expanded = jnp.tile(class_pos_embed, (hidden_states.shape[0], 1, 1))
|
||||||
|
|
||||||
return jnp.concatenate((class_pos_embed[jnp.newaxis, :], patch_pos_embed), axis=1)
|
return jnp.concatenate((class_pos_embed_expanded, patch_pos_embed_expanded), axis=1)
|
||||||
|
|
||||||
def __call__(self, pixel_values, deterministic=True):
|
def __call__(self, pixel_values, deterministic=True):
|
||||||
batch_size = pixel_values.shape[0]
|
batch_size = pixel_values.shape[0]
|
||||||
@@ -778,7 +780,7 @@ FLAX_VISION_CLASSIFICATION_DOCSTRING = """
|
|||||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||||
|
|
||||||
>>> image_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base-imagenet1k-1-layer")
|
>>> image_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base-imagenet1k-1-layer")
|
||||||
>>> model = FlaxDinov2ForImageClassification.from_pretrained("facebook/dinov2-base-imagenet1k-1-layer")
|
>>> model = FlaxDinov2ForImageClassification.from_pretrained("facebook/dinov2-base-imagenet1k-1-layer", from_pt=True)
|
||||||
|
|
||||||
>>> inputs = image_processor(images=image, return_tensors="np")
|
>>> inputs = image_processor(images=image, return_tensors="np")
|
||||||
>>> outputs = model(**inputs)
|
>>> outputs = model(**inputs)
|
||||||
|
|||||||
@@ -202,7 +202,7 @@ class FlaxDionv2ModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
|||||||
# We will verify our results on an image of cute cats
|
# We will verify our results on an image of cute cats
|
||||||
def prepare_img():
|
def prepare_img():
|
||||||
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
||||||
return image
|
return [image, image]
|
||||||
|
|
||||||
|
|
||||||
@require_vision
|
@require_vision
|
||||||
@@ -224,18 +224,25 @@ class FlaxDinov2ModelIntegrationTest(unittest.TestCase):
|
|||||||
outputs = model(pixel_values=pixel_values)
|
outputs = model(pixel_values=pixel_values)
|
||||||
|
|
||||||
# verify the logits
|
# verify the logits
|
||||||
expected_shape = (1, 257, 768)
|
expected_shape = (2, 257, 768)
|
||||||
self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
|
self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
|
||||||
|
|
||||||
expected_slice = np.array(
|
expected_slice = np.array(
|
||||||
[
|
[
|
||||||
[-2.1629121, -0.46566057, 1.0925977],
|
[
|
||||||
[-3.5971704, -1.0283585, -1.1780515],
|
[-2.1629121, -0.46566057, 1.0925977],
|
||||||
[-2.900407, 1.1334689, -0.74357724],
|
[-3.5971704, -1.0283585, -1.1780515],
|
||||||
|
[-2.900407, 1.1334689, -0.74357724],
|
||||||
|
],
|
||||||
|
[
|
||||||
|
[-2.1629121, -0.46566057, 1.0925977],
|
||||||
|
[-3.5971704, -1.0283585, -1.1780515],
|
||||||
|
[-2.900407, 1.1334689, -0.74357724],
|
||||||
|
],
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertTrue(np.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4))
|
self.assertTrue(np.allclose(outputs.last_hidden_state[:2, :3, :3], expected_slice, atol=1e-4))
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_inference_image_classification_head_imagenet_1k(self):
|
def test_inference_image_classification_head_imagenet_1k(self):
|
||||||
@@ -252,12 +259,13 @@ class FlaxDinov2ModelIntegrationTest(unittest.TestCase):
|
|||||||
logits = outputs.logits
|
logits = outputs.logits
|
||||||
|
|
||||||
# verify the logits
|
# verify the logits
|
||||||
expected_shape = (1, 1000)
|
expected_shape = (2, 1000)
|
||||||
self.assertEqual(logits.shape, expected_shape)
|
self.assertEqual(logits.shape, expected_shape)
|
||||||
|
|
||||||
expected_slice = np.array([-2.1776447, 0.36716992, 0.13870952])
|
expected_slice = np.array([[-2.1776447, 0.36716992, 0.13870952], [-2.1776447, 0.36716992, 0.13870952]])
|
||||||
|
|
||||||
self.assertTrue(np.allclose(logits[0, :3], expected_slice, atol=1e-4))
|
self.assertTrue(np.allclose(logits[:2, :3], expected_slice, atol=1e-3))
|
||||||
|
|
||||||
expected_class_idx = 281
|
expected_class_idx = 281
|
||||||
self.assertEqual(logits.argmax(-1).item(), expected_class_idx)
|
self.assertEqual(logits[0].argmax(-1).item(), expected_class_idx)
|
||||||
|
self.assertEqual(logits[1].argmax(-1).item(), expected_class_idx)
|
||||||
|
|||||||
Reference in New Issue
Block a user