mps: add isin_mps_friendly, a wrapper function for torch.isin (#33099)
This commit is contained in:
@@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
|
||||
import torch
|
||||
|
||||
from ..cache_utils import DynamicCache
|
||||
from ..pytorch_utils import isin_mps_friendly
|
||||
from .logits_process import LogitsProcessorList, MinLengthLogitsProcessor
|
||||
|
||||
|
||||
@@ -335,7 +336,7 @@ class PromptLookupCandidateGenerator(CandidateGenerator):
|
||||
# remove remaining candidate ids if an "eos" token is found, otherwise the target model may
|
||||
# accept eos and the rest as valid, thus not stopping generation after "eos"
|
||||
# NOTE: below code is written based on the fact that assisted decoding supports only bs=1
|
||||
mask = torch.isin(chosen_ids, self.eos_token_id)
|
||||
mask = isin_mps_friendly(chosen_ids, self.eos_token_id)
|
||||
match_indices_eos = torch.nonzero(mask)
|
||||
if match_indices_eos.numel() > 0:
|
||||
first_eos_index = match_indices_eos[0].item()
|
||||
|
||||
@@ -20,6 +20,7 @@ from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ..pytorch_utils import isin_mps_friendly
|
||||
from ..utils import add_start_docstrings
|
||||
from ..utils.logging import get_logger
|
||||
|
||||
@@ -159,7 +160,7 @@ class MinLengthLogitsProcessor(LogitsProcessor):
|
||||
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
|
||||
eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id)
|
||||
eos_token_mask = isin_mps_friendly(vocab_tensor, self.eos_token_id)
|
||||
scores_processed = scores.clone()
|
||||
if input_ids.shape[-1] < self.min_length:
|
||||
scores_processed = torch.where(eos_token_mask, -math.inf, scores)
|
||||
@@ -231,7 +232,7 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
|
||||
new_tokens_length = input_ids.shape[-1] - self.prompt_length_to_skip
|
||||
scores_processed = scores.clone()
|
||||
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
|
||||
eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id)
|
||||
eos_token_mask = isin_mps_friendly(vocab_tensor, self.eos_token_id)
|
||||
if new_tokens_length < self.min_new_tokens:
|
||||
scores_processed = torch.where(eos_token_mask, -math.inf, scores)
|
||||
|
||||
@@ -1795,7 +1796,7 @@ class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor):
|
||||
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
|
||||
suppress_token_mask = torch.isin(vocab_tensor, self.begin_suppress_tokens)
|
||||
suppress_token_mask = isin_mps_friendly(vocab_tensor, self.begin_suppress_tokens)
|
||||
scores_processed = scores
|
||||
if input_ids.shape[-1] == self.begin_index:
|
||||
scores_processed = torch.where(suppress_token_mask, -float("inf"), scores)
|
||||
@@ -1838,7 +1839,7 @@ class SuppressTokensLogitsProcessor(LogitsProcessor):
|
||||
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
|
||||
suppress_token_mask = torch.isin(vocab_tensor, self.suppress_tokens)
|
||||
suppress_token_mask = isin_mps_friendly(vocab_tensor, self.suppress_tokens)
|
||||
scores = torch.where(suppress_token_mask, -float("inf"), scores)
|
||||
return scores
|
||||
|
||||
|
||||
@@ -9,8 +9,7 @@ import numpy as np
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4
|
||||
|
||||
from ..pytorch_utils import isin_mps_friendly
|
||||
from ..tokenization_utils_base import PreTrainedTokenizerBase
|
||||
from ..utils import add_start_docstrings, logging
|
||||
|
||||
@@ -457,19 +456,7 @@ class EosTokenCriteria(StoppingCriteria):
|
||||
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
|
||||
self.eos_token_id = self.eos_token_id.to(input_ids.device)
|
||||
if input_ids.device.type == "mps" and not is_torch_greater_or_equal_than_2_4:
|
||||
# TODO: remove this workaround when we stop supporting torch<=2.3
|
||||
# https://github.com/pytorch/pytorch/issues/77764#issuecomment-2067838075
|
||||
is_done = (
|
||||
input_ids[:, -1]
|
||||
.tile(self.eos_token_id.shape[0], 1)
|
||||
.eq(self.eos_token_id.unsqueeze(1))
|
||||
.sum(dim=0)
|
||||
.bool()
|
||||
.squeeze()
|
||||
)
|
||||
else:
|
||||
is_done = torch.isin(input_ids[:, -1], self.eos_token_id)
|
||||
is_done = isin_mps_friendly(input_ids[:, -1], self.eos_token_id)
|
||||
return is_done
|
||||
|
||||
|
||||
|
||||
@@ -47,7 +47,7 @@ from ..models.auto import (
|
||||
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
|
||||
MODEL_FOR_VISION_2_SEQ_MAPPING,
|
||||
)
|
||||
from ..pytorch_utils import is_torch_greater_or_equal_than_2_4
|
||||
from ..pytorch_utils import isin_mps_friendly
|
||||
from ..tokenization_utils import ExtensionsTrie
|
||||
from ..utils import (
|
||||
ModelOutput,
|
||||
@@ -472,18 +472,11 @@ class GenerationMixin:
|
||||
if not is_input_ids:
|
||||
return default_attention_mask
|
||||
|
||||
# Otherwise we have may have information -> try to infer the attention mask
|
||||
if inputs.device.type == "mps" and not is_torch_greater_or_equal_than_2_4:
|
||||
# mps does not support torch.isin for torch<2.4 (https://github.com/pytorch/pytorch/issues/77764)
|
||||
raise ValueError(
|
||||
"Can't infer missing attention mask on `mps` device for torch<2.4. Please provide an `attention_mask` or upgrade to torch>=2.4"
|
||||
)
|
||||
|
||||
is_pad_token_in_inputs = (pad_token_id is not None) and (
|
||||
torch.isin(elements=inputs, test_elements=pad_token_id).any()
|
||||
isin_mps_friendly(elements=inputs, test_elements=pad_token_id).any()
|
||||
)
|
||||
is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ~(
|
||||
torch.isin(elements=eos_token_id, test_elements=pad_token_id).any()
|
||||
isin_mps_friendly(elements=eos_token_id, test_elements=pad_token_id).any()
|
||||
)
|
||||
can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
|
||||
attention_mask_from_padding = inputs.ne(pad_token_id).long()
|
||||
@@ -1660,7 +1653,7 @@ class GenerationMixin:
|
||||
if not is_torchdynamo_compiling(): # Checks that depend on tensor-dependent control flow
|
||||
if (
|
||||
eos_token_tensor is not None
|
||||
and torch.isin(elements=eos_token_tensor, test_elements=pad_token_tensor).any()
|
||||
and isin_mps_friendly(elements=eos_token_tensor, test_elements=pad_token_tensor).any()
|
||||
):
|
||||
if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
|
||||
logger.warning_once(
|
||||
|
||||
@@ -35,7 +35,7 @@ from ...modeling_outputs import (
|
||||
CausalLMOutputWithCrossAttentions,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel, SequenceSummary
|
||||
from ...pytorch_utils import Conv1D
|
||||
from ...pytorch_utils import Conv1D, isin_mps_friendly
|
||||
from ...utils import (
|
||||
ModelOutput,
|
||||
add_start_docstrings,
|
||||
@@ -132,7 +132,7 @@ def _pad_extra_bos_eos_tokens(
|
||||
)
|
||||
for i, each_input_id in enumerate(input_ids):
|
||||
# locate where the valid tokens end and then add the eos token
|
||||
if torch.isin(each_input_id, pad_token_id).sum():
|
||||
if isin_mps_friendly(each_input_id, pad_token_id).sum():
|
||||
pos = torch.where(each_input_id == pad_token_id)[0].min()
|
||||
modified_input_ids[i] = torch.concatenate(
|
||||
[each_input_id[:pos], torch.tensor([eos_token_id], device=input_ids.device), each_input_id[pos:]]
|
||||
|
||||
@@ -303,3 +303,24 @@ def id_tensor_storage(tensor: torch.Tensor) -> Tuple[torch.device, int, int]:
|
||||
unique_id = storage_ptr(tensor)
|
||||
|
||||
return tensor.device, unique_id, storage_size(tensor)
|
||||
|
||||
|
||||
def isin_mps_friendly(elements: torch.Tensor, test_elements: torch.Tensor | int) -> torch.Tensor:
|
||||
"""
|
||||
Same as `torch.isin` without flags, but MPS-friendly. We can remove this function when we stop supporting
|
||||
torch <= 2.3. See https://github.com/pytorch/pytorch/issues/77764#issuecomment-2067838075
|
||||
|
||||
Args:
|
||||
elements (`torch.Tensor`): Input elements
|
||||
test_elements (`torch.Tensor`): The elements to check against.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: A boolean tensor of the same shape as `elements` that is True for `elements` in `test_elements`
|
||||
and False otherwise
|
||||
"""
|
||||
|
||||
if elements.device.type == "mps" and not is_torch_greater_or_equal_than_2_4:
|
||||
return elements.tile(test_elements.shape[0], 1).eq(test_elements.unsqueeze(1)).sum(dim=0).bool().squeeze()
|
||||
else:
|
||||
# Note: don't use named arguments in `torch.isin`, see https://github.com/pytorch/pytorch/issues/126045
|
||||
return torch.isin(elements, test_elements)
|
||||
|
||||
@@ -106,6 +106,7 @@ if is_torch_available():
|
||||
dtype_byte_size,
|
||||
shard_checkpoint,
|
||||
)
|
||||
from transformers.pytorch_utils import isin_mps_friendly
|
||||
|
||||
# Fake pretrained models for tests
|
||||
class BaseModel(PreTrainedModel):
|
||||
@@ -1698,6 +1699,22 @@ class ModelUtilsTest(TestCasePlus):
|
||||
self.assertIn("beta_param", missing_keys)
|
||||
self.assertIn("bias_param", unexpected_keys)
|
||||
|
||||
def test_isin_mps_friendly(self):
|
||||
"""tests that our custom `isin_mps_friendly` matches `torch.isin`"""
|
||||
random_ids = torch.randint(0, 100, (100,))
|
||||
# We can match against an interger
|
||||
random_test_integer = torch.randint(0, 100, (1,)).item()
|
||||
self.assertTrue(
|
||||
torch.equal(
|
||||
torch.isin(random_ids, random_test_integer), isin_mps_friendly(random_ids, random_test_integer)
|
||||
)
|
||||
)
|
||||
# We can match against an tensor of integers
|
||||
random_test_tensor = torch.randint(0, 100, (10,))
|
||||
self.assertTrue(
|
||||
torch.equal(torch.isin(random_ids, random_test_tensor), isin_mps_friendly(random_ids, random_test_tensor))
|
||||
)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
|
||||
Reference in New Issue
Block a user