MPS: isin_mps_friendly can support 0D tensors (#34538)
* apply fix * tested * make fixup
This commit is contained in:
@@ -314,7 +314,7 @@ def isin_mps_friendly(elements: torch.Tensor, test_elements: torch.Tensor | int)
|
||||
|
||||
Args:
|
||||
elements (`torch.Tensor`): Input elements
|
||||
test_elements (`torch.Tensor`): The elements to check against.
|
||||
test_elements (`torch.Tensor` or `int`): The elements to check against.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: A boolean tensor of the same shape as `elements` that is True for `elements` in `test_elements`
|
||||
@@ -322,6 +322,9 @@ def isin_mps_friendly(elements: torch.Tensor, test_elements: torch.Tensor | int)
|
||||
"""
|
||||
|
||||
if elements.device.type == "mps" and not is_torch_greater_or_equal_than_2_4:
|
||||
test_elements = torch.tensor(test_elements)
|
||||
if test_elements.ndim == 0:
|
||||
test_elements = test_elements.unsqueeze(0)
|
||||
return elements.tile(test_elements.shape[0], 1).eq(test_elements.unsqueeze(1)).sum(dim=0).bool().squeeze()
|
||||
else:
|
||||
# Note: don't use named arguments in `torch.isin`, see https://github.com/pytorch/pytorch/issues/126045
|
||||
|
||||
@@ -1711,7 +1711,12 @@ class ModelUtilsTest(TestCasePlus):
|
||||
torch.isin(random_ids, random_test_integer), isin_mps_friendly(random_ids, random_test_integer)
|
||||
)
|
||||
)
|
||||
# We can match against an tensor of integers
|
||||
# We can match against an 0D tensor
|
||||
random_test_tensor = torch.randint(0, 100, (1,)).squeeze()
|
||||
self.assertTrue(
|
||||
torch.equal(torch.isin(random_ids, random_test_tensor), isin_mps_friendly(random_ids, random_test_tensor))
|
||||
)
|
||||
# We can match against an 1D tensor (with many items)
|
||||
random_test_tensor = torch.randint(0, 100, (10,))
|
||||
self.assertTrue(
|
||||
torch.equal(torch.isin(random_ids, random_test_tensor), isin_mps_friendly(random_ids, random_test_tensor))
|
||||
|
||||
Reference in New Issue
Block a user