Fix pytorch integration tests for SAM (#36397)

Fix device in tests
This commit is contained in:
Pavel Iakubovskii
2025-02-25 14:53:34 +00:00
committed by GitHub
parent ca6ebcb9bc
commit fb83befb14

View File

@@ -533,12 +533,10 @@ class SamModelIntegrationTest(unittest.TestCase):
with torch.no_grad():
outputs = model(**inputs)
scores = outputs.iou_scores.squeeze()
masks = outputs.pred_masks[0, 0, 0, 0, :3]
scores = outputs.iou_scores.squeeze().cpu()
masks = outputs.pred_masks[0, 0, 0, 0, :3].cpu()
torch.testing.assert_close(scores[-1], torch.tensor(0.4515), rtol=2e-4, atol=2e-4)
torch.testing.assert_close(
masks, torch.tensor([-4.1800, -3.4948, -3.4481]).to(torch_device), rtol=2e-4, atol=2e-4
)
torch.testing.assert_close(masks, torch.tensor([-4.1800, -3.4948, -3.4481]), rtol=2e-4, atol=2e-4)
def test_inference_mask_generation_one_point_one_bb(self):
model = SamModel.from_pretrained("facebook/sam-vit-base")
@@ -557,12 +555,10 @@ class SamModelIntegrationTest(unittest.TestCase):
with torch.no_grad():
outputs = model(**inputs)
scores = outputs.iou_scores.squeeze()
masks = outputs.pred_masks[0, 0, 0, 0, :3]
scores = outputs.iou_scores.squeeze().cpu()
masks = outputs.pred_masks[0, 0, 0, 0, :3].cpu()
torch.testing.assert_close(scores[-1], torch.tensor(0.9566), rtol=2e-4, atol=2e-4)
torch.testing.assert_close(
masks, torch.tensor([-12.7729, -12.3665, -12.6061]).to(torch_device), rtol=2e-4, atol=2e-4
)
torch.testing.assert_close(masks, torch.tensor([-12.7729, -12.3665, -12.6061]), rtol=2e-4, atol=2e-4)
def test_inference_mask_generation_batched_points_batched_images(self):
model = SamModel.from_pretrained("facebook/sam-vit-base")
@@ -628,7 +624,7 @@ class SamModelIntegrationTest(unittest.TestCase):
with torch.no_grad():
outputs = model(**inputs)
scores = outputs.iou_scores.squeeze()
scores = outputs.iou_scores.squeeze().cpu()
torch.testing.assert_close(scores[-1], torch.tensor(0.7894), rtol=1e-4, atol=1e-4)
@@ -650,7 +646,7 @@ class SamModelIntegrationTest(unittest.TestCase):
with torch.no_grad():
outputs = model(**inputs)
scores = outputs.iou_scores.squeeze()
scores = outputs.iou_scores.squeeze().cpu()
torch.testing.assert_close(scores[-1], torch.tensor(0.9675), rtol=1e-4, atol=1e-4)
# With no label
@@ -660,7 +656,7 @@ class SamModelIntegrationTest(unittest.TestCase):
with torch.no_grad():
outputs = model(**inputs)
scores = outputs.iou_scores.squeeze()
scores = outputs.iou_scores.squeeze().cpu()
torch.testing.assert_close(scores[-1], torch.tensor(0.9675), rtol=1e-4, atol=1e-4)
def test_inference_mask_generation_two_points(self):
@@ -681,7 +677,7 @@ class SamModelIntegrationTest(unittest.TestCase):
with torch.no_grad():
outputs = model(**inputs)
scores = outputs.iou_scores.squeeze()
scores = outputs.iou_scores.squeeze().cpu()
torch.testing.assert_close(scores[-1], torch.tensor(0.9762), rtol=1e-4, atol=1e-4)
# no labels
@@ -689,7 +685,7 @@ class SamModelIntegrationTest(unittest.TestCase):
with torch.no_grad():
outputs = model(**inputs)
scores = outputs.iou_scores.squeeze()
scores = outputs.iou_scores.squeeze().cpu()
torch.testing.assert_close(scores[-1], torch.tensor(0.9762), rtol=1e-4, atol=1e-4)
@@ -711,7 +707,7 @@ class SamModelIntegrationTest(unittest.TestCase):
with torch.no_grad():
outputs = model(**inputs)
scores = outputs.iou_scores.squeeze()
scores = outputs.iou_scores.squeeze().cpu()
torch.testing.assert_close(scores[0][-1], torch.tensor(0.9762), rtol=1e-4, atol=1e-4)
torch.testing.assert_close(scores[1][-1], torch.tensor(0.9637), rtol=1e-4, atol=1e-4)
@@ -730,7 +726,7 @@ class SamModelIntegrationTest(unittest.TestCase):
with torch.no_grad():
outputs = model(**inputs)
scores = outputs.iou_scores.squeeze()
scores = outputs.iou_scores.squeeze().cpu()
torch.testing.assert_close(scores[-1], torch.tensor(0.7937), rtol=1e-4, atol=1e-4)
def test_inference_mask_generation_batched_image_one_point(self):
@@ -751,7 +747,7 @@ class SamModelIntegrationTest(unittest.TestCase):
with torch.no_grad():
outputs = model(**inputs)
scores_batched = outputs.iou_scores.squeeze()
scores_batched = outputs.iou_scores.squeeze().cpu()
input_points = [[[220, 470]]]
@@ -759,7 +755,7 @@ class SamModelIntegrationTest(unittest.TestCase):
with torch.no_grad():
outputs = model(**inputs)
scores_single = outputs.iou_scores.squeeze()
scores_single = outputs.iou_scores.squeeze().cpu()
torch.testing.assert_close(scores_batched[1, :], scores_single, rtol=1e-4, atol=1e-4)
def test_inference_mask_generation_two_points_point_batch(self):
@@ -797,9 +793,11 @@ class SamModelIntegrationTest(unittest.TestCase):
# fmt: off
input_boxes = torch.Tensor([[[620, 900, 1000, 1255]], [[75, 275, 1725, 850]], [[75, 275, 1725, 850]]]).cpu()
EXPECTED_IOU = torch.tensor([[[0.9773, 0.9881, 0.9522],
[0.5996, 0.7661, 0.7937],
[0.5996, 0.7661, 0.7937]]])
EXPECTED_IOU = torch.tensor([[
[0.9773, 0.9881, 0.9522],
[0.5996, 0.7661, 0.7937],
[0.5996, 0.7661, 0.7937],
]])
# fmt: on
input_boxes = input_boxes.unsqueeze(0)