Tests: replace torch.testing.assert_allclose by torch.testing.assert_close (#29915)
* replace torch.testing.assert_allclose by torch.testing.assert_close * missing atol rtol
This commit is contained in:
@@ -725,7 +725,7 @@ class SamModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
iou_scores = outputs.iou_scores.cpu()
|
||||
self.assertTrue(iou_scores.shape == (1, 2, 3))
|
||||
torch.testing.assert_allclose(
|
||||
torch.testing.assert_close(
|
||||
iou_scores, torch.tensor([[[0.9105, 0.9825, 0.9675], [0.7646, 0.7943, 0.7774]]]), atol=1e-4, rtol=1e-4
|
||||
)
|
||||
|
||||
@@ -753,7 +753,7 @@ class SamModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
iou_scores = outputs.iou_scores.cpu()
|
||||
self.assertTrue(iou_scores.shape == (1, 3, 3))
|
||||
torch.testing.assert_allclose(iou_scores, EXPECTED_IOU, atol=1e-4, rtol=1e-4)
|
||||
torch.testing.assert_close(iou_scores, EXPECTED_IOU, atol=1e-4, rtol=1e-4)
|
||||
|
||||
def test_dummy_pipeline_generation(self):
|
||||
generator = pipeline("mask-generation", model="facebook/sam-vit-base", device=torch_device)
|
||||
|
||||
Reference in New Issue
Block a user