Tests: replace torch.testing.assert_allclose by torch.testing.assert_close (#29915)

* replace torch.testing.assert_allclose by torch.testing.assert_close

* missing atol rtol
This commit is contained in:
Joao Gante
2024-03-28 09:53:31 +00:00
committed by GitHub
parent 7c19fafe44
commit 248d5d23a2
7 changed files with 30 additions and 34 deletions

View File

@@ -725,7 +725,7 @@ class SamModelIntegrationTest(unittest.TestCase):
iou_scores = outputs.iou_scores.cpu()
self.assertTrue(iou_scores.shape == (1, 2, 3))
torch.testing.assert_allclose(
torch.testing.assert_close(
iou_scores, torch.tensor([[[0.9105, 0.9825, 0.9675], [0.7646, 0.7943, 0.7774]]]), atol=1e-4, rtol=1e-4
)
@@ -753,7 +753,7 @@ class SamModelIntegrationTest(unittest.TestCase):
iou_scores = outputs.iou_scores.cpu()
self.assertTrue(iou_scores.shape == (1, 3, 3))
torch.testing.assert_allclose(iou_scores, EXPECTED_IOU, atol=1e-4, rtol=1e-4)
torch.testing.assert_close(iou_scores, EXPECTED_IOU, atol=1e-4, rtol=1e-4)
def test_dummy_pipeline_generation(self):
generator = pipeline("mask-generation", model="facebook/sam-vit-base", device=torch_device)