Update SAM/SAM HQ attention implementation + fix CUDA sync issues (#39386)
* update attention implementation and improve inference speed
* modular sam_hq + fix integration tests on A10
* fixup
* fix after review
* softmax in correct place
* return attn_weights in sam/sam_hq
This commit is contained in:
@@ -806,7 +806,7 @@ class SamHQModelIntegrationTest(unittest.TestCase):
|
||||
expectations = Expectations(
|
||||
{
|
||||
(None, None): [-13.1695, -14.6201, -14.8989],
|
||||
-("cuda", 8): [-13.1668, -14.6182, -14.8970],
|
||||
+("cuda", 8): [-7.6769, -9.6935, -9.8773],
|
||||
}
|
||||
)
|
||||
EXPECTED_MASKS = torch.tensor(expectations.get_expectation()).to(torch_device)
|
||||
@@ -831,9 +831,9 @@ class SamHQModelIntegrationTest(unittest.TestCase):
|
||||
outputs = model(**inputs)
|
||||
scores = outputs.iou_scores.squeeze()
|
||||
masks = outputs.pred_masks[0, 0, 0, 0, :3]
|
||||
-self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9700), atol=2e-4))
|
||||
-self.assertTrue(
|
||||
-torch.allclose(masks, torch.tensor([-29.9144, -30.0546, -30.9526]).to(torch_device), atol=3e-2)
|
||||
+torch.testing.assert_close(scores[-1], torch.tensor(0.9700).to(torch_device), atol=2e-4, rtol=2e-4)
|
||||
+torch.testing.assert_close(
|
||||
+masks, torch.tensor([-9.2033, -8.5505, -7.1361]).to(torch_device), atol=3e-2, rtol=3e-2
|
||||
)
|
||||
|
||||
def test_inference_mask_generation_batched_points_batched_images(self):
|
||||
@@ -895,7 +895,7 @@ class SamHQModelIntegrationTest(unittest.TestCase):
|
||||
expectations = Expectations(
|
||||
{
|
||||
(None, None): [-40.2445, -37.4300, -38.1577],
|
||||
-("cuda", 8): [-40.2351, -37.4334, -38.1526],
|
||||
+("cuda", 8): [-14.1195, -17.2663, -13.7805],
|
||||
}
|
||||
)
|
||||
EXPECTED_MASKS = torch.tensor(expectations.get_expectation()).to(torch_device)
|
||||
|
||||
Reference in New Issue
Block a user