fix owlvit tests, update docstring examples (#18586)
This commit is contained in:
@@ -733,7 +733,6 @@ def prepare_img():
|
||||
|
||||
@require_vision
|
||||
@require_torch
|
||||
@unittest.skip("These tests are broken, fix me Alara")
|
||||
class OwlViTModelIntegrationTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_inference(self):
|
||||
@@ -763,8 +762,7 @@ class OwlViTModelIntegrationTest(unittest.TestCase):
|
||||
outputs.logits_per_text.shape,
|
||||
torch.Size((inputs.input_ids.shape[0], inputs.pixel_values.shape[0])),
|
||||
)
|
||||
expected_logits = torch.tensor([[4.4420, 0.6181]], device=torch_device)
|
||||
|
||||
expected_logits = torch.tensor([[3.4613, 0.9403]], device=torch_device)
|
||||
self.assertTrue(torch.allclose(outputs.logits_per_image, expected_logits, atol=1e-3))
|
||||
|
||||
@slow
|
||||
@@ -788,7 +786,8 @@ class OwlViTModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
num_queries = int((model.config.vision_config.image_size / model.config.vision_config.patch_size) ** 2)
|
||||
self.assertEqual(outputs.pred_boxes.shape, torch.Size((1, num_queries, 4)))
|
||||
|
||||
expected_slice_boxes = torch.tensor(
|
||||
[[0.0948, 0.0471, 0.1915], [0.3194, 0.0583, 0.6498], [0.1441, 0.0452, 0.2197]]
|
||||
[[0.0691, 0.0445, 0.1373], [0.1592, 0.0456, 0.3192], [0.1632, 0.0423, 0.2478]]
|
||||
).to(torch_device)
|
||||
self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4))
|
||||
|
||||
Reference in New Issue
Block a user