fix owlvit tests, update docstring examples (#18586)

This commit is contained in:
Alara Dirik
2022-08-11 19:10:25 +03:00
committed by GitHub
parent 05d3a43c59
commit f28f240828
3 changed files with 7 additions and 8 deletions

View File

@@ -733,7 +733,6 @@ def prepare_img():
@require_vision
@require_torch
@unittest.skip("These tests are broken, fix me Alara")
class OwlViTModelIntegrationTest(unittest.TestCase):
@slow
def test_inference(self):
@@ -763,8 +762,7 @@ class OwlViTModelIntegrationTest(unittest.TestCase):
outputs.logits_per_text.shape,
torch.Size((inputs.input_ids.shape[0], inputs.pixel_values.shape[0])),
)
expected_logits = torch.tensor([[4.4420, 0.6181]], device=torch_device)
expected_logits = torch.tensor([[3.4613, 0.9403]], device=torch_device)
self.assertTrue(torch.allclose(outputs.logits_per_image, expected_logits, atol=1e-3))
@slow
@@ -788,7 +786,8 @@ class OwlViTModelIntegrationTest(unittest.TestCase):
num_queries = int((model.config.vision_config.image_size / model.config.vision_config.patch_size) ** 2)
self.assertEqual(outputs.pred_boxes.shape, torch.Size((1, num_queries, 4)))
expected_slice_boxes = torch.tensor(
[[0.0948, 0.0471, 0.1915], [0.3194, 0.0583, 0.6498], [0.1441, 0.0452, 0.2197]]
[[0.0691, 0.0445, 0.1373], [0.1592, 0.0456, 0.3192], [0.1632, 0.0423, 0.2478]]
).to(torch_device)
self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4))