MPS: isin_mps_friendly can support 0D tensors (#34538)
* apply fix * tested * make fixup
This commit is contained in:
@@ -1711,7 +1711,12 @@ class ModelUtilsTest(TestCasePlus):
|
||||
torch.isin(random_ids, random_test_integer), isin_mps_friendly(random_ids, random_test_integer)
|
||||
)
|
||||
)
|
||||
# We can match against an tensor of integers
|
||||
# We can match against an 0D tensor
|
||||
random_test_tensor = torch.randint(0, 100, (1,)).squeeze()
|
||||
self.assertTrue(
|
||||
torch.equal(torch.isin(random_ids, random_test_tensor), isin_mps_friendly(random_ids, random_test_tensor))
|
||||
)
|
||||
# We can match against an 1D tensor (with many items)
|
||||
random_test_tensor = torch.randint(0, 100, (10,))
|
||||
self.assertTrue(
|
||||
torch.equal(torch.isin(random_ids, random_test_tensor), isin_mps_friendly(random_ids, random_test_tensor))
|
||||
|
||||
Reference in New Issue
Block a user