use torch.testing.assertclose instead to get more details about error in cis (#35659)
* use torch.testing.assertclose instead to get more details about error in cis * fix * style * test_all * revert for I bert * fixes and updates * more image processing fixes * more image processors * fix mamba and co * style * less strick * ok I won't be strict * skip and be done * up
This commit is contained in:
@@ -539,8 +539,10 @@ class SamModelIntegrationTest(unittest.TestCase):
|
||||
outputs = model(**inputs)
|
||||
scores = outputs.iou_scores.squeeze()
|
||||
masks = outputs.pred_masks[0, 0, 0, 0, :3]
|
||||
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.4515), atol=2e-4))
|
||||
self.assertTrue(torch.allclose(masks, torch.tensor([-4.1800, -3.4948, -3.4481]).to(torch_device), atol=2e-4))
|
||||
torch.testing.assert_close(scores[-1], torch.tensor(0.4515), rtol=2e-4, atol=2e-4)
|
||||
torch.testing.assert_close(
|
||||
masks, torch.tensor([-4.1800, -3.4948, -3.4481]).to(torch_device), rtol=2e-4, atol=2e-4
|
||||
)
|
||||
|
||||
def test_inference_mask_generation_one_point_one_bb(self):
|
||||
model = SamModel.from_pretrained("facebook/sam-vit-base")
|
||||
@@ -561,9 +563,9 @@ class SamModelIntegrationTest(unittest.TestCase):
|
||||
outputs = model(**inputs)
|
||||
scores = outputs.iou_scores.squeeze()
|
||||
masks = outputs.pred_masks[0, 0, 0, 0, :3]
|
||||
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9566), atol=2e-4))
|
||||
self.assertTrue(
|
||||
torch.allclose(masks, torch.tensor([-12.7729, -12.3665, -12.6061]).to(torch_device), atol=2e-4)
|
||||
torch.testing.assert_close(scores[-1], torch.tensor(0.9566), rtol=2e-4, atol=2e-4)
|
||||
torch.testing.assert_close(
|
||||
masks, torch.tensor([-12.7729, -12.3665, -12.6061]).to(torch_device), rtol=2e-4, atol=2e-4
|
||||
)
|
||||
|
||||
def test_inference_mask_generation_batched_points_batched_images(self):
|
||||
@@ -605,8 +607,8 @@ class SamModelIntegrationTest(unittest.TestCase):
|
||||
]
|
||||
)
|
||||
EXPECTED_MASKS = torch.tensor([-2.8550, -2.7988, -2.9625])
|
||||
self.assertTrue(torch.allclose(scores, EXPECTED_SCORES, atol=1e-3))
|
||||
self.assertTrue(torch.allclose(masks, EXPECTED_MASKS, atol=1e-3))
|
||||
torch.testing.assert_close(scores, EXPECTED_SCORES, rtol=1e-3, atol=1e-3)
|
||||
torch.testing.assert_close(masks, EXPECTED_MASKS, rtol=1e-3, atol=1e-3)
|
||||
|
||||
def test_inference_mask_generation_one_point_one_bb_zero(self):
|
||||
model = SamModel.from_pretrained("facebook/sam-vit-base")
|
||||
@@ -632,7 +634,7 @@ class SamModelIntegrationTest(unittest.TestCase):
|
||||
outputs = model(**inputs)
|
||||
scores = outputs.iou_scores.squeeze()
|
||||
|
||||
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.7894), atol=1e-4))
|
||||
torch.testing.assert_close(scores[-1], torch.tensor(0.7894), rtol=1e-4, atol=1e-4)
|
||||
|
||||
def test_inference_mask_generation_one_point(self):
|
||||
model = SamModel.from_pretrained("facebook/sam-vit-base")
|
||||
@@ -653,7 +655,7 @@ class SamModelIntegrationTest(unittest.TestCase):
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
scores = outputs.iou_scores.squeeze()
|
||||
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9675), atol=1e-4))
|
||||
torch.testing.assert_close(scores[-1], torch.tensor(0.9675), rtol=1e-4, atol=1e-4)
|
||||
|
||||
# With no label
|
||||
input_points = [[[400, 650]]]
|
||||
@@ -663,7 +665,7 @@ class SamModelIntegrationTest(unittest.TestCase):
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
scores = outputs.iou_scores.squeeze()
|
||||
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9675), atol=1e-4))
|
||||
torch.testing.assert_close(scores[-1], torch.tensor(0.9675), rtol=1e-4, atol=1e-4)
|
||||
|
||||
def test_inference_mask_generation_two_points(self):
|
||||
model = SamModel.from_pretrained("facebook/sam-vit-base")
|
||||
@@ -684,7 +686,7 @@ class SamModelIntegrationTest(unittest.TestCase):
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
scores = outputs.iou_scores.squeeze()
|
||||
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9762), atol=1e-4))
|
||||
torch.testing.assert_close(scores[-1], torch.tensor(0.9762), rtol=1e-4, atol=1e-4)
|
||||
|
||||
# no labels
|
||||
inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt").to(torch_device)
|
||||
@@ -693,7 +695,7 @@ class SamModelIntegrationTest(unittest.TestCase):
|
||||
outputs = model(**inputs)
|
||||
scores = outputs.iou_scores.squeeze()
|
||||
|
||||
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9762), atol=1e-4))
|
||||
torch.testing.assert_close(scores[-1], torch.tensor(0.9762), rtol=1e-4, atol=1e-4)
|
||||
|
||||
def test_inference_mask_generation_two_points_batched(self):
|
||||
model = SamModel.from_pretrained("facebook/sam-vit-base")
|
||||
@@ -714,8 +716,8 @@ class SamModelIntegrationTest(unittest.TestCase):
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
scores = outputs.iou_scores.squeeze()
|
||||
self.assertTrue(torch.allclose(scores[0][-1], torch.tensor(0.9762), atol=1e-4))
|
||||
self.assertTrue(torch.allclose(scores[1][-1], torch.tensor(0.9637), atol=1e-4))
|
||||
torch.testing.assert_close(scores[0][-1], torch.tensor(0.9762), rtol=1e-4, atol=1e-4)
|
||||
torch.testing.assert_close(scores[1][-1], torch.tensor(0.9637), rtol=1e-4, atol=1e-4)
|
||||
|
||||
def test_inference_mask_generation_one_box(self):
|
||||
model = SamModel.from_pretrained("facebook/sam-vit-base")
|
||||
@@ -733,7 +735,7 @@ class SamModelIntegrationTest(unittest.TestCase):
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
scores = outputs.iou_scores.squeeze()
|
||||
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.7937), atol=1e-4))
|
||||
torch.testing.assert_close(scores[-1], torch.tensor(0.7937), rtol=1e-4, atol=1e-4)
|
||||
|
||||
def test_inference_mask_generation_batched_image_one_point(self):
|
||||
model = SamModel.from_pretrained("facebook/sam-vit-base")
|
||||
@@ -762,7 +764,7 @@ class SamModelIntegrationTest(unittest.TestCase):
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
scores_single = outputs.iou_scores.squeeze()
|
||||
self.assertTrue(torch.allclose(scores_batched[1, :], scores_single, atol=1e-4))
|
||||
torch.testing.assert_close(scores_batched[1, :], scores_single, rtol=1e-4, atol=1e-4)
|
||||
|
||||
def test_inference_mask_generation_two_points_point_batch(self):
|
||||
model = SamModel.from_pretrained("facebook/sam-vit-base")
|
||||
@@ -812,7 +814,7 @@ class SamModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
iou_scores = outputs.iou_scores.cpu()
|
||||
self.assertTrue(iou_scores.shape == (1, 3, 3))
|
||||
torch.testing.assert_close(iou_scores, EXPECTED_IOU, atol=1e-4, rtol=1e-4)
|
||||
torch.testing.assert_close(iou_scores, EXPECTED_IOU, rtol=1e-4, atol=1e-4)
|
||||
|
||||
def test_dummy_pipeline_generation(self):
|
||||
generator = pipeline("mask-generation", model="facebook/sam-vit-base", device=torch_device)
|
||||
|
||||
Reference in New Issue
Block a user