mps: add isin_mps_friendly, a wrapper function for torch.isin (#33099)
This commit is contained in:
@@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ..cache_utils import DynamicCache
|
from ..cache_utils import DynamicCache
|
||||||
|
from ..pytorch_utils import isin_mps_friendly
|
||||||
from .logits_process import LogitsProcessorList, MinLengthLogitsProcessor
|
from .logits_process import LogitsProcessorList, MinLengthLogitsProcessor
|
||||||
|
|
||||||
|
|
||||||
@@ -335,7 +336,7 @@ class PromptLookupCandidateGenerator(CandidateGenerator):
|
|||||||
# remove remaining candidate ids if an "eos" token is found, otherwise the target model may
|
# remove remaining candidate ids if an "eos" token is found, otherwise the target model may
|
||||||
# accept eos and the rest as valid, thus not stopping generation after "eos"
|
# accept eos and the rest as valid, thus not stopping generation after "eos"
|
||||||
# NOTE: below code is written based on the fact that assisted decoding supports only bs=1
|
# NOTE: below code is written based on the fact that assisted decoding supports only bs=1
|
||||||
mask = torch.isin(chosen_ids, self.eos_token_id)
|
mask = isin_mps_friendly(chosen_ids, self.eos_token_id)
|
||||||
match_indices_eos = torch.nonzero(mask)
|
match_indices_eos = torch.nonzero(mask)
|
||||||
if match_indices_eos.numel() > 0:
|
if match_indices_eos.numel() > 0:
|
||||||
first_eos_index = match_indices_eos[0].item()
|
first_eos_index = match_indices_eos[0].item()
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from ..pytorch_utils import isin_mps_friendly
|
||||||
from ..utils import add_start_docstrings
|
from ..utils import add_start_docstrings
|
||||||
from ..utils.logging import get_logger
|
from ..utils.logging import get_logger
|
||||||
|
|
||||||
@@ -159,7 +160,7 @@ class MinLengthLogitsProcessor(LogitsProcessor):
|
|||||||
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
|
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
|
||||||
eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id)
|
eos_token_mask = isin_mps_friendly(vocab_tensor, self.eos_token_id)
|
||||||
scores_processed = scores.clone()
|
scores_processed = scores.clone()
|
||||||
if input_ids.shape[-1] < self.min_length:
|
if input_ids.shape[-1] < self.min_length:
|
||||||
scores_processed = torch.where(eos_token_mask, -math.inf, scores)
|
scores_processed = torch.where(eos_token_mask, -math.inf, scores)
|
||||||
@@ -231,7 +232,7 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
|
|||||||
new_tokens_length = input_ids.shape[-1] - self.prompt_length_to_skip
|
new_tokens_length = input_ids.shape[-1] - self.prompt_length_to_skip
|
||||||
scores_processed = scores.clone()
|
scores_processed = scores.clone()
|
||||||
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
|
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
|
||||||
eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id)
|
eos_token_mask = isin_mps_friendly(vocab_tensor, self.eos_token_id)
|
||||||
if new_tokens_length < self.min_new_tokens:
|
if new_tokens_length < self.min_new_tokens:
|
||||||
scores_processed = torch.where(eos_token_mask, -math.inf, scores)
|
scores_processed = torch.where(eos_token_mask, -math.inf, scores)
|
||||||
|
|
||||||
@@ -1795,7 +1796,7 @@ class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor):
|
|||||||
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
|
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
|
||||||
suppress_token_mask = torch.isin(vocab_tensor, self.begin_suppress_tokens)
|
suppress_token_mask = isin_mps_friendly(vocab_tensor, self.begin_suppress_tokens)
|
||||||
scores_processed = scores
|
scores_processed = scores
|
||||||
if input_ids.shape[-1] == self.begin_index:
|
if input_ids.shape[-1] == self.begin_index:
|
||||||
scores_processed = torch.where(suppress_token_mask, -float("inf"), scores)
|
scores_processed = torch.where(suppress_token_mask, -float("inf"), scores)
|
||||||
@@ -1838,7 +1839,7 @@ class SuppressTokensLogitsProcessor(LogitsProcessor):
|
|||||||
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
|
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
|
||||||
suppress_token_mask = torch.isin(vocab_tensor, self.suppress_tokens)
|
suppress_token_mask = isin_mps_friendly(vocab_tensor, self.suppress_tokens)
|
||||||
scores = torch.where(suppress_token_mask, -float("inf"), scores)
|
scores = torch.where(suppress_token_mask, -float("inf"), scores)
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
|||||||
@@ -9,8 +9,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
|
||||||
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4
|
from ..pytorch_utils import isin_mps_friendly
|
||||||
|
|
||||||
from ..tokenization_utils_base import PreTrainedTokenizerBase
|
from ..tokenization_utils_base import PreTrainedTokenizerBase
|
||||||
from ..utils import add_start_docstrings, logging
|
from ..utils import add_start_docstrings, logging
|
||||||
|
|
||||||
@@ -457,19 +456,7 @@ class EosTokenCriteria(StoppingCriteria):
|
|||||||
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
|
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
|
||||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
|
||||||
self.eos_token_id = self.eos_token_id.to(input_ids.device)
|
self.eos_token_id = self.eos_token_id.to(input_ids.device)
|
||||||
if input_ids.device.type == "mps" and not is_torch_greater_or_equal_than_2_4:
|
is_done = isin_mps_friendly(input_ids[:, -1], self.eos_token_id)
|
||||||
# TODO: remove this workaround when we stop supporting torch<=2.3
|
|
||||||
# https://github.com/pytorch/pytorch/issues/77764#issuecomment-2067838075
|
|
||||||
is_done = (
|
|
||||||
input_ids[:, -1]
|
|
||||||
.tile(self.eos_token_id.shape[0], 1)
|
|
||||||
.eq(self.eos_token_id.unsqueeze(1))
|
|
||||||
.sum(dim=0)
|
|
||||||
.bool()
|
|
||||||
.squeeze()
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
is_done = torch.isin(input_ids[:, -1], self.eos_token_id)
|
|
||||||
return is_done
|
return is_done
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ from ..models.auto import (
|
|||||||
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
|
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
|
||||||
MODEL_FOR_VISION_2_SEQ_MAPPING,
|
MODEL_FOR_VISION_2_SEQ_MAPPING,
|
||||||
)
|
)
|
||||||
from ..pytorch_utils import is_torch_greater_or_equal_than_2_4
|
from ..pytorch_utils import isin_mps_friendly
|
||||||
from ..tokenization_utils import ExtensionsTrie
|
from ..tokenization_utils import ExtensionsTrie
|
||||||
from ..utils import (
|
from ..utils import (
|
||||||
ModelOutput,
|
ModelOutput,
|
||||||
@@ -472,18 +472,11 @@ class GenerationMixin:
|
|||||||
if not is_input_ids:
|
if not is_input_ids:
|
||||||
return default_attention_mask
|
return default_attention_mask
|
||||||
|
|
||||||
# Otherwise we may have information -> try to infer the attention mask
|
|
||||||
if inputs.device.type == "mps" and not is_torch_greater_or_equal_than_2_4:
|
|
||||||
# mps does not support torch.isin for torch<2.4 (https://github.com/pytorch/pytorch/issues/77764)
|
|
||||||
raise ValueError(
|
|
||||||
"Can't infer missing attention mask on `mps` device for torch<2.4. Please provide an `attention_mask` or upgrade to torch>=2.4"
|
|
||||||
)
|
|
||||||
|
|
||||||
is_pad_token_in_inputs = (pad_token_id is not None) and (
|
is_pad_token_in_inputs = (pad_token_id is not None) and (
|
||||||
torch.isin(elements=inputs, test_elements=pad_token_id).any()
|
isin_mps_friendly(elements=inputs, test_elements=pad_token_id).any()
|
||||||
)
|
)
|
||||||
is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ~(
|
is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ~(
|
||||||
torch.isin(elements=eos_token_id, test_elements=pad_token_id).any()
|
isin_mps_friendly(elements=eos_token_id, test_elements=pad_token_id).any()
|
||||||
)
|
)
|
||||||
can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
|
can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
|
||||||
attention_mask_from_padding = inputs.ne(pad_token_id).long()
|
attention_mask_from_padding = inputs.ne(pad_token_id).long()
|
||||||
@@ -1660,7 +1653,7 @@ class GenerationMixin:
|
|||||||
if not is_torchdynamo_compiling(): # Checks that depend on tensor-dependent control flow
|
if not is_torchdynamo_compiling(): # Checks that depend on tensor-dependent control flow
|
||||||
if (
|
if (
|
||||||
eos_token_tensor is not None
|
eos_token_tensor is not None
|
||||||
and torch.isin(elements=eos_token_tensor, test_elements=pad_token_tensor).any()
|
and isin_mps_friendly(elements=eos_token_tensor, test_elements=pad_token_tensor).any()
|
||||||
):
|
):
|
||||||
if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
|
if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ from ...modeling_outputs import (
|
|||||||
CausalLMOutputWithCrossAttentions,
|
CausalLMOutputWithCrossAttentions,
|
||||||
)
|
)
|
||||||
from ...modeling_utils import PreTrainedModel, SequenceSummary
|
from ...modeling_utils import PreTrainedModel, SequenceSummary
|
||||||
from ...pytorch_utils import Conv1D
|
from ...pytorch_utils import Conv1D, isin_mps_friendly
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
ModelOutput,
|
ModelOutput,
|
||||||
add_start_docstrings,
|
add_start_docstrings,
|
||||||
@@ -132,7 +132,7 @@ def _pad_extra_bos_eos_tokens(
|
|||||||
)
|
)
|
||||||
for i, each_input_id in enumerate(input_ids):
|
for i, each_input_id in enumerate(input_ids):
|
||||||
# locate where the valid tokens end and then add the eos token
|
# locate where the valid tokens end and then add the eos token
|
||||||
if torch.isin(each_input_id, pad_token_id).sum():
|
if isin_mps_friendly(each_input_id, pad_token_id).sum():
|
||||||
pos = torch.where(each_input_id == pad_token_id)[0].min()
|
pos = torch.where(each_input_id == pad_token_id)[0].min()
|
||||||
modified_input_ids[i] = torch.concatenate(
|
modified_input_ids[i] = torch.concatenate(
|
||||||
[each_input_id[:pos], torch.tensor([eos_token_id], device=input_ids.device), each_input_id[pos:]]
|
[each_input_id[:pos], torch.tensor([eos_token_id], device=input_ids.device), each_input_id[pos:]]
|
||||||
|
|||||||
@@ -303,3 +303,24 @@ def id_tensor_storage(tensor: torch.Tensor) -> Tuple[torch.device, int, int]:
|
|||||||
unique_id = storage_ptr(tensor)
|
unique_id = storage_ptr(tensor)
|
||||||
|
|
||||||
return tensor.device, unique_id, storage_size(tensor)
|
return tensor.device, unique_id, storage_size(tensor)
|
||||||
|
|
||||||
|
|
||||||
|
def isin_mps_friendly(elements: torch.Tensor, test_elements: "torch.Tensor | int") -> torch.Tensor:
    """
    Same as `torch.isin` without flags, but MPS-friendly. We can remove this function when we stop supporting
    torch <= 2.3. See https://github.com/pytorch/pytorch/issues/77764#issuecomment-2067838075

    Args:
        elements (`torch.Tensor`): Input elements
        test_elements (`torch.Tensor` or `int`): The element(s) to check against. A plain integer or a 0-dim
            tensor is treated as a single test value.

    Returns:
        `torch.Tensor`: A boolean tensor of the same shape as `elements` that is True for `elements` in `test_elements`
        and False otherwise
    """

    if elements.device.type == "mps" and not is_torch_greater_or_equal_than_2_4:
        # `test_elements` may be a bare int or a 0-dim tensor, neither of which has a usable `shape[0]`;
        # normalize it to a flat 1D tensor living on the same device as `elements`.
        test_elements = torch.as_tensor(test_elements, device=elements.device).reshape(-1)
        # Broadcast (num_test, 1, ..., 1) against (1, *elements.shape) so that `elements` of any rank is
        # supported, then reduce over the test axis. No `squeeze()`: the output keeps the shape of `elements`.
        test_column = test_elements.view(-1, *([1] * elements.dim()))
        return elements.unsqueeze(0).eq(test_column).sum(dim=0).bool()
    else:
        # Note: don't use named arguments in `torch.isin`, see https://github.com/pytorch/pytorch/issues/126045
        return torch.isin(elements, test_elements)
|
||||||
|
|||||||
@@ -106,6 +106,7 @@ if is_torch_available():
|
|||||||
dtype_byte_size,
|
dtype_byte_size,
|
||||||
shard_checkpoint,
|
shard_checkpoint,
|
||||||
)
|
)
|
||||||
|
from transformers.pytorch_utils import isin_mps_friendly
|
||||||
|
|
||||||
# Fake pretrained models for tests
|
# Fake pretrained models for tests
|
||||||
class BaseModel(PreTrainedModel):
|
class BaseModel(PreTrainedModel):
|
||||||
@@ -1698,6 +1699,22 @@ class ModelUtilsTest(TestCasePlus):
|
|||||||
self.assertIn("beta_param", missing_keys)
|
self.assertIn("beta_param", missing_keys)
|
||||||
self.assertIn("bias_param", unexpected_keys)
|
self.assertIn("bias_param", unexpected_keys)
|
||||||
|
|
||||||
|
def test_isin_mps_friendly(self):
|
||||||
|
"""tests that our custom `isin_mps_friendly` matches `torch.isin`"""
|
||||||
|
random_ids = torch.randint(0, 100, (100,))
|
||||||
|
# We can match against an integer
|
||||||
|
random_test_integer = torch.randint(0, 100, (1,)).item()
|
||||||
|
self.assertTrue(
|
||||||
|
torch.equal(
|
||||||
|
torch.isin(random_ids, random_test_integer), isin_mps_friendly(random_ids, random_test_integer)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# We can match against a tensor of integers
|
||||||
|
random_test_tensor = torch.randint(0, 100, (10,))
|
||||||
|
self.assertTrue(
|
||||||
|
torch.equal(torch.isin(random_ids, random_test_tensor), isin_mps_friendly(random_ids, random_test_tensor))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch
|
@require_torch
|
||||||
|
|||||||
Reference in New Issue
Block a user