mps: add isin_mps_friendly, a wrapper function for torch.isin (#33099)
This commit is contained in:
@@ -106,6 +106,7 @@ if is_torch_available():
|
||||
dtype_byte_size,
|
||||
shard_checkpoint,
|
||||
)
|
||||
from transformers.pytorch_utils import isin_mps_friendly
|
||||
|
||||
# Fake pretrained models for tests
|
||||
class BaseModel(PreTrainedModel):
|
||||
@@ -1698,6 +1699,22 @@ class ModelUtilsTest(TestCasePlus):
|
||||
self.assertIn("beta_param", missing_keys)
|
||||
self.assertIn("bias_param", unexpected_keys)
|
||||
|
||||
def test_isin_mps_friendly(self):
|
||||
"""tests that our custom `isin_mps_friendly` matches `torch.isin`"""
|
||||
random_ids = torch.randint(0, 100, (100,))
|
||||
# We can match against an interger
|
||||
random_test_integer = torch.randint(0, 100, (1,)).item()
|
||||
self.assertTrue(
|
||||
torch.equal(
|
||||
torch.isin(random_ids, random_test_integer), isin_mps_friendly(random_ids, random_test_integer)
|
||||
)
|
||||
)
|
||||
# We can match against an tensor of integers
|
||||
random_test_tensor = torch.randint(0, 100, (10,))
|
||||
self.assertTrue(
|
||||
torch.equal(torch.isin(random_ids, random_test_tensor), isin_mps_friendly(random_ids, random_test_tensor))
|
||||
)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
|
||||
Reference in New Issue
Block a user