Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9fe3f585bb | ||
|
|
f8fec6b0ad |
2
setup.py
2
setup.py
@@ -429,7 +429,7 @@ install_requires = [
|
||||
|
||||
setup(
|
||||
name="transformers",
|
||||
version="4.40.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
version="4.40.1", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
|
||||
author_email="transformers@huggingface.co",
|
||||
description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
|
||||
# in the namespace without actually importing anything (and especially none of the backends).
|
||||
|
||||
__version__ = "4.40.0"
|
||||
__version__ = "4.40.1"
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
||||
@@ -146,7 +146,18 @@ class EosTokenCriteria(StoppingCriteria):
|
||||
|
||||
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
|
||||
is_done = torch.isin(input_ids[:, -1], self.eos_token_id.to(input_ids.device))
|
||||
if input_ids.device.type == "mps":
|
||||
# https://github.com/pytorch/pytorch/issues/77764#issuecomment-2067838075
|
||||
is_done = (
|
||||
input_ids[:, -1]
|
||||
.tile(self.eos_token_id.shape[0], 1)
|
||||
.eq(self.eos_token_id.unsqueeze(1).to(input_ids.device))
|
||||
.sum(dim=0)
|
||||
.bool()
|
||||
.squeeze()
|
||||
)
|
||||
else:
|
||||
is_done = torch.isin(input_ids[:, -1], self.eos_token_id.to(input_ids.device))
|
||||
return is_done
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user