use torch.testing.assertclose instead to get more details about error in cis (#35659)
* use torch.testing.assertclose instead to get more details about error in cis * fix * style * test_all * revert for I bert * fixes and updates * more image processing fixes * more image processors * fix mamba and co * style * less strick * ok I won't be strict * skip and be done * up
This commit is contained in:
@@ -819,7 +819,7 @@ class OwlViTModelIntegrationTest(unittest.TestCase):
|
||||
torch.Size((inputs.input_ids.shape[0], inputs.pixel_values.shape[0])),
|
||||
)
|
||||
expected_logits = torch.tensor([[3.4613, 0.9403]], device=torch_device)
|
||||
self.assertTrue(torch.allclose(outputs.logits_per_image, expected_logits, atol=1e-3))
|
||||
torch.testing.assert_close(outputs.logits_per_image, expected_logits, rtol=1e-3, atol=1e-3)
|
||||
|
||||
@slow
|
||||
def test_inference_interpolate_pos_encoding(self):
|
||||
@@ -851,7 +851,7 @@ class OwlViTModelIntegrationTest(unittest.TestCase):
|
||||
torch.Size((inputs.input_ids.shape[0], inputs.pixel_values.shape[0])),
|
||||
)
|
||||
expected_logits = torch.tensor([[3.6278, 0.8861]], device=torch_device)
|
||||
self.assertTrue(torch.allclose(outputs.logits_per_image, expected_logits, atol=1e-3))
|
||||
torch.testing.assert_close(outputs.logits_per_image, expected_logits, rtol=1e-3, atol=1e-3)
|
||||
|
||||
expected_shape = torch.Size((1, 626, 768))
|
||||
self.assertEqual(outputs.vision_model_output.last_hidden_state.shape, expected_shape)
|
||||
@@ -868,7 +868,7 @@ class OwlViTModelIntegrationTest(unittest.TestCase):
|
||||
expected_slice_boxes = torch.tensor(
|
||||
[[0.0680, 0.0422, 0.1347], [0.2071, 0.0450, 0.4146], [0.2000, 0.0418, 0.3476]]
|
||||
).to(torch_device)
|
||||
self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4))
|
||||
torch.testing.assert_close(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, rtol=1e-4, atol=1e-4)
|
||||
|
||||
model = OwlViTForObjectDetection.from_pretrained(model_name).to(torch_device)
|
||||
query_image = prepare_img()
|
||||
@@ -913,7 +913,7 @@ class OwlViTModelIntegrationTest(unittest.TestCase):
|
||||
[-1.9452, -3.1332, -3.1332, -3.1332],
|
||||
]
|
||||
)
|
||||
self.assertTrue(torch.allclose(model.box_bias[:3, :4], expected_default_box_bias, atol=1e-4))
|
||||
torch.testing.assert_close(model.box_bias[:3, :4], expected_default_box_bias, rtol=1e-4, atol=1e-4)
|
||||
|
||||
# Interpolate with any resolution size.
|
||||
processor.image_processor.size = {"height": 1264, "width": 1024}
|
||||
@@ -938,7 +938,7 @@ class OwlViTModelIntegrationTest(unittest.TestCase):
|
||||
expected_slice_boxes = torch.tensor(
|
||||
[[0.0499, 0.0301, 0.0983], [0.2244, 0.0365, 0.4663], [0.1387, 0.0314, 0.1859]]
|
||||
).to(torch_device)
|
||||
self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4))
|
||||
torch.testing.assert_close(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, rtol=1e-4, atol=1e-4)
|
||||
|
||||
query_image = prepare_img()
|
||||
inputs = processor(
|
||||
@@ -985,7 +985,7 @@ class OwlViTModelIntegrationTest(unittest.TestCase):
|
||||
expected_slice_boxes = torch.tensor(
|
||||
[[0.0691, 0.0445, 0.1373], [0.1592, 0.0456, 0.3192], [0.1632, 0.0423, 0.2478]]
|
||||
).to(torch_device)
|
||||
self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4))
|
||||
torch.testing.assert_close(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, rtol=1e-4, atol=1e-4)
|
||||
|
||||
# test post-processing
|
||||
post_processed_output = processor.post_process_grounded_object_detection(outputs)
|
||||
@@ -1028,7 +1028,7 @@ class OwlViTModelIntegrationTest(unittest.TestCase):
|
||||
expected_slice_boxes = torch.tensor(
|
||||
[[0.0691, 0.0445, 0.1373], [0.1592, 0.0456, 0.3192], [0.1632, 0.0423, 0.2478]]
|
||||
).to(torch_device)
|
||||
self.assertTrue(torch.allclose(outputs.target_pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4))
|
||||
torch.testing.assert_close(outputs.target_pred_boxes[0, :3, :3], expected_slice_boxes, rtol=1e-4, atol=1e-4)
|
||||
|
||||
@slow
|
||||
@require_torch_accelerator
|
||||
|
||||
Reference in New Issue
Block a user