Fix MinNewTokensLengthLogitsProcessor when used with a list of eos tokens (#21959)
* Fix MinNewTokensLengthLogitsProcessor when used with a list of eos tokens
* fix docs
* Empty commit
* formatting
This commit is contained in:
@@ -133,19 +133,23 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
|
|||||||
The input tokens length.
|
The input tokens length.
|
||||||
min_new_tokens (`int`):
|
min_new_tokens (`int`):
|
||||||
The minimum *new* tokens length below which the score of `eos_token_id` is set to `-float("Inf")`.
|
The minimum *new* tokens length below which the score of `eos_token_id` is set to `-float("Inf")`.
|
||||||
eos_token_id (`int`):
|
eos_token_id (`Union[int, List[int]]`):
|
||||||
The id of the *end-of-sequence* token.
|
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id: int):
|
def __init__(self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id: Union[int, List[int]]):
|
||||||
for arg_name, arg_value in [
|
for arg_name, arg_value in [
|
||||||
("prompt_length_to_skip", prompt_length_to_skip),
|
("prompt_length_to_skip", prompt_length_to_skip),
|
||||||
("min_new_tokens", min_new_tokens),
|
("min_new_tokens", min_new_tokens),
|
||||||
("eos_token_id", eos_token_id),
|
|
||||||
]:
|
]:
|
||||||
if not isinstance(arg_value, int) or arg_value < 0:
|
if not isinstance(arg_value, int) or arg_value < 0:
|
||||||
raise ValueError(f"`{arg_name}` has to be a positive integer, but is {arg_value}")
|
raise ValueError(f"`{arg_name}` has to be a positive integer, but is {arg_value}")
|
||||||
|
|
||||||
|
if isinstance(eos_token_id, int):
|
||||||
|
eos_token_id = [eos_token_id]
|
||||||
|
if not all([isinstance(i, int) for i in eos_token_id]) or any([i < 0 for i in eos_token_id]):
|
||||||
|
raise ValueError(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")
|
||||||
|
|
||||||
self.prompt_length_to_skip = prompt_length_to_skip
|
self.prompt_length_to_skip = prompt_length_to_skip
|
||||||
self.min_new_tokens = min_new_tokens
|
self.min_new_tokens = min_new_tokens
|
||||||
self.eos_token_id = eos_token_id
|
self.eos_token_id = eos_token_id
|
||||||
@@ -153,7 +157,8 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
|
|||||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
new_tokens_length = input_ids.shape[-1] - self.prompt_length_to_skip
|
new_tokens_length = input_ids.shape[-1] - self.prompt_length_to_skip
|
||||||
if new_tokens_length < self.min_new_tokens:
|
if new_tokens_length < self.min_new_tokens:
|
||||||
scores[:, self.eos_token_id] = -float("inf")
|
for i in self.eos_token_id:
|
||||||
|
scores[:, i] = -float("inf")
|
||||||
|
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
|||||||
@@ -13,8 +13,10 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
from parameterized import parameterized
|
||||||
|
|
||||||
from transformers import is_torch_available
|
from transformers import is_torch_available
|
||||||
from transformers.testing_utils import require_torch, torch_device
|
from transformers.testing_utils import require_torch, torch_device
|
||||||
@@ -76,10 +78,10 @@ class LogitsProcessorTest(unittest.TestCase):
|
|||||||
scores_before_min_length = min_dist_processor(input_ids, scores)
|
scores_before_min_length = min_dist_processor(input_ids, scores)
|
||||||
self.assertFalse(torch.isinf(scores_before_min_length).any())
|
self.assertFalse(torch.isinf(scores_before_min_length).any())
|
||||||
|
|
||||||
def test_new_min_length_dist_processor(self):
|
@parameterized.expand([(0,), ([0, 18],)])
|
||||||
|
def test_new_min_length_dist_processor(self, eos_token_id: Union[int, List[int]]):
|
||||||
vocab_size = 20
|
vocab_size = 20
|
||||||
batch_size = 4
|
batch_size = 4
|
||||||
eos_token_id = 0
|
|
||||||
|
|
||||||
# check that first input is skipped (min new length applying)
|
# check that first input is skipped (min new length applying)
|
||||||
input_ids = ids_tensor((batch_size, 5), vocab_size=20)
|
input_ids = ids_tensor((batch_size, 5), vocab_size=20)
|
||||||
@@ -87,9 +89,15 @@ class LogitsProcessorTest(unittest.TestCase):
|
|||||||
prompt_length_to_skip=input_ids.shape[-1], min_new_tokens=3, eos_token_id=eos_token_id
|
prompt_length_to_skip=input_ids.shape[-1], min_new_tokens=3, eos_token_id=eos_token_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
expected_eos_scores_before_min_length = batch_size * [-float("inf")]
|
||||||
|
if isinstance(eos_token_id, list):
|
||||||
|
expected_eos_scores_before_min_length *= len(eos_token_id)
|
||||||
|
|
||||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||||
scores_before_min_length = new_min_dist_processor(input_ids, scores)
|
scores_before_min_length = new_min_dist_processor(input_ids, scores)
|
||||||
self.assertListEqual(scores_before_min_length[:, eos_token_id].tolist(), batch_size * [-float("inf")])
|
self.assertListEqual(
|
||||||
|
scores_before_min_length[:, eos_token_id].flatten().tolist(), expected_eos_scores_before_min_length
|
||||||
|
)
|
||||||
|
|
||||||
# check that, for skipping, now prompt length is 5, after that we expect first 5 tokens will be skipped
|
# check that, for skipping, now prompt length is 5, after that we expect first 5 tokens will be skipped
|
||||||
self.assertTrue(new_min_dist_processor.prompt_length_to_skip == 5)
|
self.assertTrue(new_min_dist_processor.prompt_length_to_skip == 5)
|
||||||
@@ -98,19 +106,25 @@ class LogitsProcessorTest(unittest.TestCase):
|
|||||||
input_ids = ids_tensor((batch_size, 2), vocab_size=20)
|
input_ids = ids_tensor((batch_size, 2), vocab_size=20)
|
||||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||||
scores_before_min_length = new_min_dist_processor(input_ids, scores)
|
scores_before_min_length = new_min_dist_processor(input_ids, scores)
|
||||||
self.assertListEqual(scores_before_min_length[:, eos_token_id].tolist(), batch_size * [-float("inf")])
|
self.assertListEqual(
|
||||||
|
scores_before_min_length[:, eos_token_id].flatten().tolist(), expected_eos_scores_before_min_length
|
||||||
|
)
|
||||||
|
|
||||||
# check that min new length is applied at length 6 (because it has only 1 new token)
|
# check that min new length is applied at length 6 (because it has only 1 new token)
|
||||||
input_ids = ids_tensor((batch_size, 6), vocab_size=20)
|
input_ids = ids_tensor((batch_size, 6), vocab_size=20)
|
||||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||||
scores_before_min_length = new_min_dist_processor(input_ids, scores)
|
scores_before_min_length = new_min_dist_processor(input_ids, scores)
|
||||||
self.assertListEqual(scores_before_min_length[:, eos_token_id].tolist(), batch_size * [-float("inf")])
|
self.assertListEqual(
|
||||||
|
scores_before_min_length[:, eos_token_id].flatten().tolist(), expected_eos_scores_before_min_length
|
||||||
|
)
|
||||||
|
|
||||||
# check that min new length is applied at length 7 (because it has only 2 new tokens)
|
# check that min new length is applied at length 7 (because it has only 2 new tokens)
|
||||||
input_ids = ids_tensor((batch_size, 7), vocab_size=20)
|
input_ids = ids_tensor((batch_size, 7), vocab_size=20)
|
||||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||||
scores_before_min_length = new_min_dist_processor(input_ids, scores)
|
scores_before_min_length = new_min_dist_processor(input_ids, scores)
|
||||||
self.assertListEqual(scores_before_min_length[:, eos_token_id].tolist(), batch_size * [-float("inf")])
|
self.assertListEqual(
|
||||||
|
scores_before_min_length[:, eos_token_id].flatten().tolist(), expected_eos_scores_before_min_length
|
||||||
|
)
|
||||||
|
|
||||||
# check that min new length is not applied anymore at length 8
|
# check that min new length is not applied anymore at length 8
|
||||||
input_ids = ids_tensor((batch_size, 8), vocab_size=20)
|
input_ids = ids_tensor((batch_size, 8), vocab_size=20)
|
||||||
|
|||||||
Reference in New Issue
Block a user