mps: add isin_mps_friendly, a wrapper function for torch.isin (#33099)

This commit is contained in:
Joao Gante
2024-08-26 15:34:19 +01:00
committed by GitHub
parent 894d421ee5
commit 72d4a3f9c1
7 changed files with 53 additions and 33 deletions

View File

@@ -106,6 +106,7 @@ if is_torch_available():
dtype_byte_size,
shard_checkpoint,
)
from transformers.pytorch_utils import isin_mps_friendly
# Fake pretrained models for tests
class BaseModel(PreTrainedModel):
@@ -1698,6 +1699,22 @@ class ModelUtilsTest(TestCasePlus):
self.assertIn("beta_param", missing_keys)
self.assertIn("bias_param", unexpected_keys)
def test_isin_mps_friendly(self):
"""tests that our custom `isin_mps_friendly` matches `torch.isin`"""
random_ids = torch.randint(0, 100, (100,))
# We can match against an interger
random_test_integer = torch.randint(0, 100, (1,)).item()
self.assertTrue(
torch.equal(
torch.isin(random_ids, random_test_integer), isin_mps_friendly(random_ids, random_test_integer)
)
)
# We can match against an tensor of integers
random_test_tensor = torch.randint(0, 100, (10,))
self.assertTrue(
torch.equal(torch.isin(random_ids, random_test_tensor), isin_mps_friendly(random_ids, random_test_tensor))
)
@slow
@require_torch