mps: add isin_mps_friendly, a wrapper function for torch.isin (#33099)

This commit is contained in:
Joao Gante
2024-08-26 15:34:19 +01:00
committed by GitHub
parent 894d421ee5
commit 72d4a3f9c1
7 changed files with 53 additions and 33 deletions

View File

@@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
import torch
from ..cache_utils import DynamicCache
from ..pytorch_utils import isin_mps_friendly
from .logits_process import LogitsProcessorList, MinLengthLogitsProcessor
@@ -335,7 +336,7 @@ class PromptLookupCandidateGenerator(CandidateGenerator):
# remove remaining candidate ids if an "eos" token is found, otherwise the target model may
# accept eos and the rest as valid, thus not stopping generation after "eos"
# NOTE: below code is written based on the fact that assisted decoding supports only bs=1
mask = torch.isin(chosen_ids, self.eos_token_id)
mask = isin_mps_friendly(chosen_ids, self.eos_token_id)
match_indices_eos = torch.nonzero(mask)
if match_indices_eos.numel() > 0:
first_eos_index = match_indices_eos[0].item()

View File

@@ -20,6 +20,7 @@ from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
import torch
from ..pytorch_utils import isin_mps_friendly
from ..utils import add_start_docstrings
from ..utils.logging import get_logger
@@ -159,7 +160,7 @@ class MinLengthLogitsProcessor(LogitsProcessor):
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id)
eos_token_mask = isin_mps_friendly(vocab_tensor, self.eos_token_id)
scores_processed = scores.clone()
if input_ids.shape[-1] < self.min_length:
scores_processed = torch.where(eos_token_mask, -math.inf, scores)
@@ -231,7 +232,7 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
new_tokens_length = input_ids.shape[-1] - self.prompt_length_to_skip
scores_processed = scores.clone()
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id)
eos_token_mask = isin_mps_friendly(vocab_tensor, self.eos_token_id)
if new_tokens_length < self.min_new_tokens:
scores_processed = torch.where(eos_token_mask, -math.inf, scores)
@@ -1795,7 +1796,7 @@ class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor):
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
suppress_token_mask = torch.isin(vocab_tensor, self.begin_suppress_tokens)
suppress_token_mask = isin_mps_friendly(vocab_tensor, self.begin_suppress_tokens)
scores_processed = scores
if input_ids.shape[-1] == self.begin_index:
scores_processed = torch.where(suppress_token_mask, -float("inf"), scores)
@@ -1838,7 +1839,7 @@ class SuppressTokensLogitsProcessor(LogitsProcessor):
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
suppress_token_mask = torch.isin(vocab_tensor, self.suppress_tokens)
suppress_token_mask = isin_mps_friendly(vocab_tensor, self.suppress_tokens)
scores = torch.where(suppress_token_mask, -float("inf"), scores)
return scores

View File

@@ -9,8 +9,7 @@ import numpy as np
import torch
from torch.nn import functional as F
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4
from ..pytorch_utils import isin_mps_friendly
from ..tokenization_utils_base import PreTrainedTokenizerBase
from ..utils import add_start_docstrings, logging
@@ -457,19 +456,7 @@ class EosTokenCriteria(StoppingCriteria):
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
self.eos_token_id = self.eos_token_id.to(input_ids.device)
if input_ids.device.type == "mps" and not is_torch_greater_or_equal_than_2_4:
# TODO: remove this workaround when we stop supporting torch<=2.3
# https://github.com/pytorch/pytorch/issues/77764#issuecomment-2067838075
is_done = (
input_ids[:, -1]
.tile(self.eos_token_id.shape[0], 1)
.eq(self.eos_token_id.unsqueeze(1))
.sum(dim=0)
.bool()
.squeeze()
)
else:
is_done = torch.isin(input_ids[:, -1], self.eos_token_id)
is_done = isin_mps_friendly(input_ids[:, -1], self.eos_token_id)
return is_done

View File

@@ -47,7 +47,7 @@ from ..models.auto import (
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
MODEL_FOR_VISION_2_SEQ_MAPPING,
)
from ..pytorch_utils import is_torch_greater_or_equal_than_2_4
from ..pytorch_utils import isin_mps_friendly
from ..tokenization_utils import ExtensionsTrie
from ..utils import (
ModelOutput,
@@ -472,18 +472,11 @@ class GenerationMixin:
if not is_input_ids:
return default_attention_mask
# Otherwise we have may have information -> try to infer the attention mask
if inputs.device.type == "mps" and not is_torch_greater_or_equal_than_2_4:
# mps does not support torch.isin for torch<2.4 (https://github.com/pytorch/pytorch/issues/77764)
raise ValueError(
"Can't infer missing attention mask on `mps` device for torch<2.4. Please provide an `attention_mask` or upgrade to torch>=2.4"
)
is_pad_token_in_inputs = (pad_token_id is not None) and (
torch.isin(elements=inputs, test_elements=pad_token_id).any()
isin_mps_friendly(elements=inputs, test_elements=pad_token_id).any()
)
is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ~(
torch.isin(elements=eos_token_id, test_elements=pad_token_id).any()
isin_mps_friendly(elements=eos_token_id, test_elements=pad_token_id).any()
)
can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
attention_mask_from_padding = inputs.ne(pad_token_id).long()
@@ -1660,7 +1653,7 @@ class GenerationMixin:
if not is_torchdynamo_compiling(): # Checks that depend on tensor-dependent control flow
if (
eos_token_tensor is not None
and torch.isin(elements=eos_token_tensor, test_elements=pad_token_tensor).any()
and isin_mps_friendly(elements=eos_token_tensor, test_elements=pad_token_tensor).any()
):
if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
logger.warning_once(

View File

@@ -35,7 +35,7 @@ from ...modeling_outputs import (
CausalLMOutputWithCrossAttentions,
)
from ...modeling_utils import PreTrainedModel, SequenceSummary
from ...pytorch_utils import Conv1D
from ...pytorch_utils import Conv1D, isin_mps_friendly
from ...utils import (
ModelOutput,
add_start_docstrings,
@@ -132,7 +132,7 @@ def _pad_extra_bos_eos_tokens(
)
for i, each_input_id in enumerate(input_ids):
# locate where the valid tokens end and then add the eos token
if torch.isin(each_input_id, pad_token_id).sum():
if isin_mps_friendly(each_input_id, pad_token_id).sum():
pos = torch.where(each_input_id == pad_token_id)[0].min()
modified_input_ids[i] = torch.concatenate(
[each_input_id[:pos], torch.tensor([eos_token_id], device=input_ids.device), each_input_id[pos:]]

View File

@@ -303,3 +303,24 @@ def id_tensor_storage(tensor: torch.Tensor) -> Tuple[torch.device, int, int]:
unique_id = storage_ptr(tensor)
return tensor.device, unique_id, storage_size(tensor)
def isin_mps_friendly(elements: torch.Tensor, test_elements: torch.Tensor | int) -> torch.Tensor:
"""
Same as `torch.isin` without flags, but MPS-friendly. We can remove this function when we stop supporting
torch <= 2.3. See https://github.com/pytorch/pytorch/issues/77764#issuecomment-2067838075
Args:
elements (`torch.Tensor`): Input elements
test_elements (`torch.Tensor`): The elements to check against.
Returns:
`torch.Tensor`: A boolean tensor of the same shape as `elements` that is True for `elements` in `test_elements`
and False otherwise
"""
if elements.device.type == "mps" and not is_torch_greater_or_equal_than_2_4:
return elements.tile(test_elements.shape[0], 1).eq(test_elements.unsqueeze(1)).sum(dim=0).bool().squeeze()
else:
# Note: don't use named arguments in `torch.isin`, see https://github.com/pytorch/pytorch/issues/126045
return torch.isin(elements, test_elements)

View File

@@ -106,6 +106,7 @@ if is_torch_available():
dtype_byte_size,
shard_checkpoint,
)
from transformers.pytorch_utils import isin_mps_friendly
# Fake pretrained models for tests
class BaseModel(PreTrainedModel):
@@ -1698,6 +1699,22 @@ class ModelUtilsTest(TestCasePlus):
self.assertIn("beta_param", missing_keys)
self.assertIn("bias_param", unexpected_keys)
def test_isin_mps_friendly(self):
"""tests that our custom `isin_mps_friendly` matches `torch.isin`"""
random_ids = torch.randint(0, 100, (100,))
# We can match against an interger
random_test_integer = torch.randint(0, 100, (1,)).item()
self.assertTrue(
torch.equal(
torch.isin(random_ids, random_test_integer), isin_mps_friendly(random_ids, random_test_integer)
)
)
# We can match against an tensor of integers
random_test_tensor = torch.randint(0, 100, (10,))
self.assertTrue(
torch.equal(torch.isin(random_ids, random_test_tensor), isin_mps_friendly(random_ids, random_test_tensor))
)
@slow
@require_torch