From 72d4a3f9c159685baae2274392562c9db97f4b64 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Mon, 26 Aug 2024 15:34:19 +0100 Subject: [PATCH] mps: add `isin_mps_friendly`, a wrapper function for `torch.isin` (#33099) --- .../generation/candidate_generator.py | 3 ++- src/transformers/generation/logits_process.py | 9 ++++---- .../generation/stopping_criteria.py | 17 ++------------- src/transformers/generation/utils.py | 15 ++++--------- src/transformers/models/clvp/modeling_clvp.py | 4 ++-- src/transformers/pytorch_utils.py | 21 +++++++++++++++++++ tests/utils/test_modeling_utils.py | 17 +++++++++++++++ 7 files changed, 53 insertions(+), 33 deletions(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index bf55ae3e2b..7e4096c0aa 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple import torch from ..cache_utils import DynamicCache +from ..pytorch_utils import isin_mps_friendly from .logits_process import LogitsProcessorList, MinLengthLogitsProcessor @@ -335,7 +336,7 @@ class PromptLookupCandidateGenerator(CandidateGenerator): # remove remaining candidate ids if an "eos" token is found, otherwise the target model may # accept eos and the rest as valid, thus not stopping generation after "eos" # NOTE: below code is written based on the fact that assisted decoding supports only bs=1 - mask = torch.isin(chosen_ids, self.eos_token_id) + mask = isin_mps_friendly(chosen_ids, self.eos_token_id) match_indices_eos = torch.nonzero(mask) if match_indices_eos.numel() > 0: first_eos_index = match_indices_eos[0].item() diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index e9ba456068..c586a97459 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ 
-20,6 +20,7 @@ from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union import numpy as np import torch +from ..pytorch_utils import isin_mps_friendly from ..utils import add_start_docstrings from ..utils.logging import get_logger @@ -159,7 +160,7 @@ class MinLengthLogitsProcessor(LogitsProcessor): @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: vocab_tensor = torch.arange(scores.shape[-1], device=scores.device) - eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id) + eos_token_mask = isin_mps_friendly(vocab_tensor, self.eos_token_id) scores_processed = scores.clone() if input_ids.shape[-1] < self.min_length: scores_processed = torch.where(eos_token_mask, -math.inf, scores) @@ -231,7 +232,7 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor): new_tokens_length = input_ids.shape[-1] - self.prompt_length_to_skip scores_processed = scores.clone() vocab_tensor = torch.arange(scores.shape[-1], device=scores.device) - eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id) + eos_token_mask = isin_mps_friendly(vocab_tensor, self.eos_token_id) if new_tokens_length < self.min_new_tokens: scores_processed = torch.where(eos_token_mask, -math.inf, scores) @@ -1795,7 +1796,7 @@ class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor): @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: vocab_tensor = torch.arange(scores.shape[-1], device=scores.device) - suppress_token_mask = torch.isin(vocab_tensor, self.begin_suppress_tokens) + suppress_token_mask = isin_mps_friendly(vocab_tensor, self.begin_suppress_tokens) scores_processed = scores if input_ids.shape[-1] == self.begin_index: scores_processed = torch.where(suppress_token_mask, -float("inf"), scores) @@ -1838,7 +1839,7 @@ class SuppressTokensLogitsProcessor(LogitsProcessor): 
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: vocab_tensor = torch.arange(scores.shape[-1], device=scores.device) - suppress_token_mask = torch.isin(vocab_tensor, self.suppress_tokens) + suppress_token_mask = isin_mps_friendly(vocab_tensor, self.suppress_tokens) scores = torch.where(suppress_token_mask, -float("inf"), scores) return scores diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index 961b6d6f5e..b8d6540ca2 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -9,8 +9,7 @@ import numpy as np import torch from torch.nn import functional as F -from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4 - +from ..pytorch_utils import isin_mps_friendly from ..tokenization_utils_base import PreTrainedTokenizerBase from ..utils import add_start_docstrings, logging @@ -457,19 +456,7 @@ class EosTokenCriteria(StoppingCriteria): @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor: self.eos_token_id = self.eos_token_id.to(input_ids.device) - if input_ids.device.type == "mps" and not is_torch_greater_or_equal_than_2_4: - # TODO: remove this workaround when we stop supporting torch<=2.3 - # https://github.com/pytorch/pytorch/issues/77764#issuecomment-2067838075 - is_done = ( - input_ids[:, -1] - .tile(self.eos_token_id.shape[0], 1) - .eq(self.eos_token_id.unsqueeze(1)) - .sum(dim=0) - .bool() - .squeeze() - ) - else: - is_done = torch.isin(input_ids[:, -1], self.eos_token_id) + is_done = isin_mps_friendly(input_ids[:, -1], self.eos_token_id) return is_done diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index e3c70ac109..0d2baea6d8 100644 --- 
a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -47,7 +47,7 @@ from ..models.auto import ( MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING, MODEL_FOR_VISION_2_SEQ_MAPPING, ) -from ..pytorch_utils import is_torch_greater_or_equal_than_2_4 +from ..pytorch_utils import isin_mps_friendly from ..tokenization_utils import ExtensionsTrie from ..utils import ( ModelOutput, @@ -472,18 +472,11 @@ class GenerationMixin: if not is_input_ids: return default_attention_mask - # Otherwise we have may have information -> try to infer the attention mask - if inputs.device.type == "mps" and not is_torch_greater_or_equal_than_2_4: - # mps does not support torch.isin for torch<2.4 (https://github.com/pytorch/pytorch/issues/77764) - raise ValueError( - "Can't infer missing attention mask on `mps` device for torch<2.4. Please provide an `attention_mask` or upgrade to torch>=2.4" - ) - is_pad_token_in_inputs = (pad_token_id is not None) and ( - torch.isin(elements=inputs, test_elements=pad_token_id).any() + isin_mps_friendly(elements=inputs, test_elements=pad_token_id).any() ) is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ~( - torch.isin(elements=eos_token_id, test_elements=pad_token_id).any() + isin_mps_friendly(elements=eos_token_id, test_elements=pad_token_id).any() ) can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id attention_mask_from_padding = inputs.ne(pad_token_id).long() @@ -1660,7 +1653,7 @@ class GenerationMixin: if not is_torchdynamo_compiling(): # Checks that depend on tensor-dependent control flow if ( eos_token_tensor is not None - and torch.isin(elements=eos_token_tensor, test_elements=pad_token_tensor).any() + and isin_mps_friendly(elements=eos_token_tensor, test_elements=pad_token_tensor).any() ): if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask: logger.warning_once( diff --git a/src/transformers/models/clvp/modeling_clvp.py 
b/src/transformers/models/clvp/modeling_clvp.py index 4db7f42517..b6d025a0b8 100644 --- a/src/transformers/models/clvp/modeling_clvp.py +++ b/src/transformers/models/clvp/modeling_clvp.py @@ -35,7 +35,7 @@ from ...modeling_outputs import ( CausalLMOutputWithCrossAttentions, ) from ...modeling_utils import PreTrainedModel, SequenceSummary -from ...pytorch_utils import Conv1D +from ...pytorch_utils import Conv1D, isin_mps_friendly from ...utils import ( ModelOutput, add_start_docstrings, @@ -132,7 +132,7 @@ def _pad_extra_bos_eos_tokens( ) for i, each_input_id in enumerate(input_ids): # locate where the valid tokens end and then add the eos token - if torch.isin(each_input_id, pad_token_id).sum(): + if isin_mps_friendly(each_input_id, pad_token_id).sum(): pos = torch.where(each_input_id == pad_token_id)[0].min() modified_input_ids[i] = torch.concatenate( [each_input_id[:pos], torch.tensor([eos_token_id], device=input_ids.device), each_input_id[pos:]] diff --git a/src/transformers/pytorch_utils.py b/src/transformers/pytorch_utils.py index 4c74a04d4f..8c1bd21fb2 100644 --- a/src/transformers/pytorch_utils.py +++ b/src/transformers/pytorch_utils.py @@ -303,3 +303,24 @@ def id_tensor_storage(tensor: torch.Tensor) -> Tuple[torch.device, int, int]: unique_id = storage_ptr(tensor) return tensor.device, unique_id, storage_size(tensor) + + +def isin_mps_friendly(elements: torch.Tensor, test_elements: torch.Tensor | int) -> torch.Tensor: + """ + Same as `torch.isin` without flags, but MPS-friendly. We can remove this function when we stop supporting + torch <= 2.3. See https://github.com/pytorch/pytorch/issues/77764#issuecomment-2067838075 + + Args: + elements (`torch.Tensor`): Input elements + test_elements (`torch.Tensor`): The elements to check against. 
+ + Returns: + `torch.Tensor`: A boolean tensor of the same shape as `elements` that is True for `elements` in `test_elements` + and False otherwise + """ + + if elements.device.type == "mps" and not is_torch_greater_or_equal_than_2_4: + return elements.tile(test_elements.shape[0], 1).eq(test_elements.unsqueeze(1)).sum(dim=0).bool().squeeze() + else: + # Note: don't use named arguments in `torch.isin`, see https://github.com/pytorch/pytorch/issues/126045 + return torch.isin(elements, test_elements) diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index 238a9a1fe4..f78285fdb9 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -106,6 +106,7 @@ if is_torch_available(): dtype_byte_size, shard_checkpoint, ) + from transformers.pytorch_utils import isin_mps_friendly # Fake pretrained models for tests class BaseModel(PreTrainedModel): @@ -1698,6 +1699,22 @@ class ModelUtilsTest(TestCasePlus): self.assertIn("beta_param", missing_keys) self.assertIn("bias_param", unexpected_keys) + def test_isin_mps_friendly(self): + """tests that our custom `isin_mps_friendly` matches `torch.isin`""" + random_ids = torch.randint(0, 100, (100,)) + # We can match against an integer + random_test_integer = torch.randint(0, 100, (1,)).item() + self.assertTrue( + torch.equal( + torch.isin(random_ids, random_test_integer), isin_mps_friendly(random_ids, random_test_integer) + ) + ) + # We can match against a tensor of integers + random_test_tensor = torch.randint(0, 100, (10,)) + self.assertTrue( + torch.equal(torch.isin(random_ids, random_test_tensor), isin_mps_friendly(random_ids, random_test_tensor)) + ) + @slow @require_torch