MPS: isin_mps_friendly can support 0D tensors (#34538)

* apply fix

* tested

* make fixup
This commit is contained in:
Joao Gante
2024-11-04 16:18:50 +00:00
committed by GitHub
parent 187439c3fa
commit 34927b0f73
2 changed files with 10 additions and 2 deletions

View File

@@ -1711,7 +1711,12 @@ class ModelUtilsTest(TestCasePlus):
torch.isin(random_ids, random_test_integer), isin_mps_friendly(random_ids, random_test_integer)
)
)
# We can match against an tensor of integers
# We can match against an 0D tensor
random_test_tensor = torch.randint(0, 100, (1,)).squeeze()
self.assertTrue(
torch.equal(torch.isin(random_ids, random_test_tensor), isin_mps_friendly(random_ids, random_test_tensor))
)
# We can match against an 1D tensor (with many items)
random_test_tensor = torch.randint(0, 100, (10,))
self.assertTrue(
torch.equal(torch.isin(random_ids, random_test_tensor), isin_mps_friendly(random_ids, random_test_tensor))