diff --git a/src/transformers/models/superglue/modeling_superglue.py b/src/transformers/models/superglue/modeling_superglue.py index 3286b8912c..60cb52c7e8 100644 --- a/src/transformers/models/superglue/modeling_superglue.py +++ b/src/transformers/models/superglue/modeling_superglue.py @@ -676,8 +676,10 @@ class SuperGlueForKeypointMatching(SuperGluePreTrainedModel): if mask is not None: mask = mask.reshape(batch_size, 2, num_keypoints) - mask0 = mask[:, 0].unsqueeze(-1).expand(-1, -1, num_keypoints) - scores = scores.masked_fill(mask0 == 0, -1e9) + mask0 = mask[:, 0].unsqueeze(2) + mask1 = mask[:, 1].unsqueeze(1) + mask = torch.logical_and(mask0, mask1) + scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min) # Run the optimal transport. scores = log_optimal_transport(scores, self.bin_score, iterations=self.config.sinkhorn_iterations) diff --git a/tests/models/superglue/test_modeling_superglue.py b/tests/models/superglue/test_modeling_superglue.py index 84fbb6c4bf..daafbef62b 100644 --- a/tests/models/superglue/test_modeling_superglue.py +++ b/tests/models/superglue/test_modeling_superglue.py @@ -423,3 +423,5 @@ class SuperGlueModelIntegrationTest(unittest.TestCase): torch.sum(~torch.isclose(predicted_matching_scores_values, expected_matching_scores_values, atol=1e-2)) < 4 ) self.assertTrue(torch.sum(predicted_matches_values != expected_matches_values) < 4) + self.assertTrue(torch.all(outputs.matches[0, 1] < torch.sum(outputs.mask[0, 0]))) + self.assertTrue(torch.all(outputs.matches[0, 0] < torch.sum(outputs.mask[0, 1])))