Update SAM/SAM HQ attention implementation + fix Cuda sync issues (#39386)

* update attention implementation and improve inference speed

* modular sam_hq + fix integration tests on A10

* fixup

* fix after review

* softmax in correct place

* return attn_weights in sam/sam_hq
This commit is contained in:
Yoni Gozlan
2025-07-18 18:46:27 -04:00
committed by GitHub
parent 541bed22d6
commit 433d2a23d7
4 changed files with 149 additions and 187 deletions

View File

@@ -806,7 +806,7 @@ class SamHQModelIntegrationTest(unittest.TestCase):
expectations = Expectations(
{
(None, None): [-13.1695, -14.6201, -14.8989],
("cuda", 8): [-13.1668, -14.6182, -14.8970],
("cuda", 8): [-7.6769, -9.6935, -9.8773],
}
)
EXPECTED_MASKS = torch.tensor(expectations.get_expectation()).to(torch_device)
@@ -831,9 +831,9 @@ class SamHQModelIntegrationTest(unittest.TestCase):
outputs = model(**inputs)
scores = outputs.iou_scores.squeeze()
masks = outputs.pred_masks[0, 0, 0, 0, :3]
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9700), atol=2e-4))
self.assertTrue(
torch.allclose(masks, torch.tensor([-29.9144, -30.0546, -30.9526]).to(torch_device), atol=3e-2)
torch.testing.assert_close(scores[-1], torch.tensor(0.9700).to(torch_device), atol=2e-4, rtol=2e-4)
torch.testing.assert_close(
masks, torch.tensor([-9.2033, -8.5505, -7.1361]).to(torch_device), atol=3e-2, rtol=3e-2
)
def test_inference_mask_generation_batched_points_batched_images(self):
@@ -895,7 +895,7 @@ class SamHQModelIntegrationTest(unittest.TestCase):
expectations = Expectations(
{
(None, None): [-40.2445, -37.4300, -38.1577],
("cuda", 8): [-40.2351, -37.4334, -38.1526],
("cuda", 8): [-14.1195, -17.2663, -13.7805],
}
)
EXPECTED_MASKS = torch.tensor(expectations.get_expectation()).to(torch_device)