* feat: add min new length logit processor * test: add min new length logit processor * docs: add MinNewTokensLengthLogitsProcessor * feat: import MinNewTokensLengthLogitsProcessor * fix: update pytorch dummy objects * refactor & fix: rename attributes and var and get rid of dynamic attribute * tests: align test with new interface * docs: fix typo * docs: minor clarification * Empty-Commit * empty commit * run automated quality edits Co-authored-by: Joao Gante <joao@huggingface.co>
This commit is contained in:
@@ -116,6 +116,9 @@ generation.
|
|||||||
[[autodoc]] MinLengthLogitsProcessor
|
[[autodoc]] MinLengthLogitsProcessor
|
||||||
- __call__
|
- __call__
|
||||||
|
|
||||||
|
[[autodoc]] MinNewTokensLengthLogitsProcessor
|
||||||
|
- __call__
|
||||||
|
|
||||||
[[autodoc]] TemperatureLogitsWarper
|
[[autodoc]] TemperatureLogitsWarper
|
||||||
- __call__
|
- __call__
|
||||||
|
|
||||||
|
|||||||
@@ -886,6 +886,7 @@ else:
|
|||||||
"MaxLengthCriteria",
|
"MaxLengthCriteria",
|
||||||
"MaxTimeCriteria",
|
"MaxTimeCriteria",
|
||||||
"MinLengthLogitsProcessor",
|
"MinLengthLogitsProcessor",
|
||||||
|
"MinNewTokensLengthLogitsProcessor",
|
||||||
"NoBadWordsLogitsProcessor",
|
"NoBadWordsLogitsProcessor",
|
||||||
"NoRepeatNGramLogitsProcessor",
|
"NoRepeatNGramLogitsProcessor",
|
||||||
"PhrasalConstraint",
|
"PhrasalConstraint",
|
||||||
@@ -4140,6 +4141,7 @@ if TYPE_CHECKING:
|
|||||||
MaxLengthCriteria,
|
MaxLengthCriteria,
|
||||||
MaxTimeCriteria,
|
MaxTimeCriteria,
|
||||||
MinLengthLogitsProcessor,
|
MinLengthLogitsProcessor,
|
||||||
|
MinNewTokensLengthLogitsProcessor,
|
||||||
NoBadWordsLogitsProcessor,
|
NoBadWordsLogitsProcessor,
|
||||||
NoRepeatNGramLogitsProcessor,
|
NoRepeatNGramLogitsProcessor,
|
||||||
PhrasalConstraint,
|
PhrasalConstraint,
|
||||||
|
|||||||
@@ -51,6 +51,7 @@ else:
|
|||||||
"LogitsProcessorList",
|
"LogitsProcessorList",
|
||||||
"LogitsWarper",
|
"LogitsWarper",
|
||||||
"MinLengthLogitsProcessor",
|
"MinLengthLogitsProcessor",
|
||||||
|
"MinNewTokensLengthLogitsProcessor",
|
||||||
"NoBadWordsLogitsProcessor",
|
"NoBadWordsLogitsProcessor",
|
||||||
"NoRepeatNGramLogitsProcessor",
|
"NoRepeatNGramLogitsProcessor",
|
||||||
"PrefixConstrainedLogitsProcessor",
|
"PrefixConstrainedLogitsProcessor",
|
||||||
@@ -171,6 +172,7 @@ if TYPE_CHECKING:
|
|||||||
LogitsProcessorList,
|
LogitsProcessorList,
|
||||||
LogitsWarper,
|
LogitsWarper,
|
||||||
MinLengthLogitsProcessor,
|
MinLengthLogitsProcessor,
|
||||||
|
MinNewTokensLengthLogitsProcessor,
|
||||||
NoBadWordsLogitsProcessor,
|
NoBadWordsLogitsProcessor,
|
||||||
NoRepeatNGramLogitsProcessor,
|
NoRepeatNGramLogitsProcessor,
|
||||||
PrefixConstrainedLogitsProcessor,
|
PrefixConstrainedLogitsProcessor,
|
||||||
|
|||||||
@@ -121,6 +121,42 @@ class MinLengthLogitsProcessor(LogitsProcessor):
|
|||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
    r"""
    [`LogitsProcessor`] enforcing a min-length of *new* tokens by setting the EOS (End-Of-Sequence) token score to
    `-float("inf")` (i.e. probability 0) while fewer than `min_new_tokens` tokens have been added after the prompt.

    The number of new tokens is computed as `input_ids.shape[-1] - prompt_length_to_skip`.

    Args:
        prompt_length_to_skip (`int`):
            The input tokens length. It is subtracted from the current sequence length to get the new tokens length.
        min_new_tokens (`int`):
            The minimum *new* tokens length below which the score of `eos_token_id` is set to `-float("Inf")`.
        eos_token_id (`int`):
            The id of the *end-of-sequence* token.

    Raises:
        ValueError: If any of the arguments is not an `int` or is negative.
    """

    def __init__(self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id: int):
        for arg_name, arg_value in [
            ("prompt_length_to_skip", prompt_length_to_skip),
            ("min_new_tokens", min_new_tokens),
            ("eos_token_id", eos_token_id),
        ]:
            # Zero is accepted for every argument (only negatives are rejected), so the message
            # says "non-negative" rather than "positive".
            if not isinstance(arg_value, int) or arg_value < 0:
                raise ValueError(f"`{arg_name}` has to be a non-negative integer, but is {arg_value}")

        self.prompt_length_to_skip = prompt_length_to_skip
        self.min_new_tokens = min_new_tokens
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """
        Mask the EOS score while the number of new tokens is below `min_new_tokens`.

        `scores` is modified in place (column `eos_token_id` is overwritten) and the same tensor is returned.
        """
        # Can be negative when `input_ids` is shorter than the recorded prompt length; EOS is masked then too.
        new_tokens_length = input_ids.shape[-1] - self.prompt_length_to_skip
        if new_tokens_length < self.min_new_tokens:
            scores[:, self.eos_token_id] = -float("inf")

        return scores
|
||||||
|
|
||||||
|
|
||||||
class TemperatureLogitsWarper(LogitsWarper):
|
class TemperatureLogitsWarper(LogitsWarper):
|
||||||
r"""
|
r"""
|
||||||
[`LogitsWarper`] for temperature (exponential scaling output probability distribution).
|
[`LogitsWarper`] for temperature (exponential scaling output probability distribution).
|
||||||
|
|||||||
@@ -199,6 +199,13 @@ class MinLengthLogitsProcessor(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
# Dummy stand-in for `MinNewTokensLengthLogitsProcessor`, declared with the "torch" backend requirement.
# NOTE(review): presumably kept in lockstep with the sibling dummy classes in this file -- keep the shape identical.
class MinNewTokensLengthLogitsProcessor(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        # Accepts any arguments and hands the instance to `requires_backends` together with the "torch" requirement.
        requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class NoBadWordsLogitsProcessor(metaclass=DummyObject):
|
class NoBadWordsLogitsProcessor(metaclass=DummyObject):
|
||||||
_backends = ["torch"]
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ if is_torch_available():
|
|||||||
LogitNormalization,
|
LogitNormalization,
|
||||||
LogitsProcessorList,
|
LogitsProcessorList,
|
||||||
MinLengthLogitsProcessor,
|
MinLengthLogitsProcessor,
|
||||||
|
MinNewTokensLengthLogitsProcessor,
|
||||||
NoBadWordsLogitsProcessor,
|
NoBadWordsLogitsProcessor,
|
||||||
NoRepeatNGramLogitsProcessor,
|
NoRepeatNGramLogitsProcessor,
|
||||||
PrefixConstrainedLogitsProcessor,
|
PrefixConstrainedLogitsProcessor,
|
||||||
@@ -72,6 +73,54 @@ class LogitsProcessorTest(unittest.TestCase):
|
|||||||
scores_before_min_length = min_dist_processor(input_ids, scores)
|
scores_before_min_length = min_dist_processor(input_ids, scores)
|
||||||
self.assertFalse(torch.isinf(scores_before_min_length).any())
|
self.assertFalse(torch.isinf(scores_before_min_length).any())
|
||||||
|
|
||||||
|
def test_new_min_length_dist_processor(self):
    """
    Check that `MinNewTokensLengthLogitsProcessor` masks the EOS score until `min_new_tokens` tokens have been
    added after the prompt, and leaves the scores untouched afterwards.
    """
    vocab_size = 20
    batch_size = 4
    eos_token_id = 0
    prompt_length = 5
    min_new_tokens = 3

    input_ids = ids_tensor((batch_size, prompt_length), vocab_size=vocab_size)
    new_min_dist_processor = MinNewTokensLengthLogitsProcessor(
        prompt_length_to_skip=input_ids.shape[-1], min_new_tokens=min_new_tokens, eos_token_id=eos_token_id
    )

    # check that, for skipping, the prompt length is 5, so the first 5 tokens are not counted as new tokens
    self.assertEqual(new_min_dist_processor.prompt_length_to_skip, prompt_length)

    # check that min new length is applied:
    #   - length 5: the prompt itself, 0 new tokens
    #   - length 2: shorter than the prompt, still fewer than 3 new tokens
    #   - length 6: only 1 new token
    #   - length 7: only 2 new tokens
    for seq_length in (5, 2, 6, 7):
        with self.subTest(seq_length=seq_length):
            input_ids = ids_tensor((batch_size, seq_length), vocab_size=vocab_size)
            scores = self._get_uniform_logits(batch_size, vocab_size)
            scores_before_min_length = new_min_dist_processor(input_ids, scores)
            self.assertListEqual(
                scores_before_min_length[:, eos_token_id].tolist(), batch_size * [-float("inf")]
            )

    # check that min new length is not applied anymore at length 8 (3 new tokens) and at length 15
    for seq_length in (8, 15):
        with self.subTest(seq_length=seq_length):
            input_ids = ids_tensor((batch_size, seq_length), vocab_size=vocab_size)
            scores = self._get_uniform_logits(batch_size, vocab_size)
            scores_before_min_length = new_min_dist_processor(input_ids, scores)
            self.assertFalse(torch.isinf(scores_before_min_length).any())
|
||||||
|
|
||||||
def test_temperature_dist_warper(self):
|
def test_temperature_dist_warper(self):
|
||||||
input_ids = None
|
input_ids = None
|
||||||
length = 20
|
length = 20
|
||||||
|
|||||||
Reference in New Issue
Block a user