Compare commits

...

2 Commits

Author SHA1 Message Date
ArthurZucker
9fe3f585bb v4.40.1
Some checks failed
Release - Conda / build_and_package (push) Has been cancelled
2024-04-23 17:05:53 -04:00
Pedro Cuenca
f8fec6b0ad Make EosTokenCriteria compatible with mps (#30376) 2024-04-23 17:04:45 -04:00
3 changed files with 14 additions and 3 deletions

View File

@@ -429,7 +429,7 @@ install_requires = [
setup(
name="transformers",
version="4.40.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="4.40.1", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
author_email="transformers@huggingface.co",
description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",

View File

@@ -18,7 +18,7 @@
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
# in the namespace without actually importing anything (and especially none of the backends).
__version__ = "4.40.0"
__version__ = "4.40.1"
from typing import TYPE_CHECKING

View File

@@ -146,7 +146,18 @@ class EosTokenCriteria(StoppingCriteria):
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
is_done = torch.isin(input_ids[:, -1], self.eos_token_id.to(input_ids.device))
if input_ids.device.type == "mps":
# https://github.com/pytorch/pytorch/issues/77764#issuecomment-2067838075
is_done = (
input_ids[:, -1]
.tile(self.eos_token_id.shape[0], 1)
.eq(self.eos_token_id.unsqueeze(1).to(input_ids.device))
.sum(dim=0)
.bool()
.squeeze()
)
else:
is_done = torch.isin(input_ids[:, -1], self.eos_token_id.to(input_ids.device))
return is_done