From 367fdf3330121a075c06d796bb95dfb1c69c65e4 Mon Sep 17 00:00:00 2001 From: Konstantin Kotik <22777646+kotikkonstantin@users.noreply.github.com> Date: Tue, 3 Jan 2023 14:29:02 +0300 Subject: [PATCH] `MinNewTokensLengthLogitsProcessor` for `.generate` method #20814 (#20892) * feat: add min new length logit processor * test: add min new length logit processor * docs: add MinNewTokensLengthLogitsProcessor * feat: import MinNewTokensLengthLogitsProcessor * fix: update pytorch dummy objects * refactor & fix: rename attributes and var and get rid of dynamic attribute * tests: align test with new interface * docs: fix typo * docs: minor clarification * Empty-Commit * empty commit * run automated quality edits Co-authored-by: Joao Gante --- docs/source/en/internal/generation_utils.mdx | 3 ++ src/transformers/__init__.py | 2 + src/transformers/generation/__init__.py | 2 + src/transformers/generation/logits_process.py | 36 ++++++++++++++ src/transformers/utils/dummy_pt_objects.py | 7 +++ tests/generation/test_logits_process.py | 49 +++++++++++++++++++ 6 files changed, 99 insertions(+) diff --git a/docs/source/en/internal/generation_utils.mdx b/docs/source/en/internal/generation_utils.mdx index b1174f2008..3c86b7dc3f 100644 --- a/docs/source/en/internal/generation_utils.mdx +++ b/docs/source/en/internal/generation_utils.mdx @@ -116,6 +116,9 @@ generation. [[autodoc]] MinLengthLogitsProcessor - __call__ +[[autodoc]] MinNewTokensLengthLogitsProcessor + - __call__ + [[autodoc]] TemperatureLogitsWarper - __call__ diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index bf53737e96..88c23302ae 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -886,6 +886,7 @@ else: "MaxLengthCriteria", "MaxTimeCriteria", "MinLengthLogitsProcessor", + "MinNewTokensLengthLogitsProcessor", "NoBadWordsLogitsProcessor", "NoRepeatNGramLogitsProcessor", "PhrasalConstraint", @@ -4140,6 +4141,7 @@ if TYPE_CHECKING: MaxLengthCriteria, MaxTimeCriteria, MinLengthLogitsProcessor, + MinNewTokensLengthLogitsProcessor, NoBadWordsLogitsProcessor, NoRepeatNGramLogitsProcessor, PhrasalConstraint, diff --git a/src/transformers/generation/__init__.py b/src/transformers/generation/__init__.py index b1d8e8acad..d0c3a32973 100644 --- a/src/transformers/generation/__init__.py +++ b/src/transformers/generation/__init__.py @@ -51,6 +51,7 @@ else: "LogitsProcessorList", "LogitsWarper", "MinLengthLogitsProcessor", + "MinNewTokensLengthLogitsProcessor", "NoBadWordsLogitsProcessor", "NoRepeatNGramLogitsProcessor", "PrefixConstrainedLogitsProcessor", @@ -171,6 +172,7 @@ if TYPE_CHECKING: LogitsProcessorList, LogitsWarper, MinLengthLogitsProcessor, + MinNewTokensLengthLogitsProcessor, NoBadWordsLogitsProcessor, NoRepeatNGramLogitsProcessor, PrefixConstrainedLogitsProcessor, diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index 5f54008e16..09d88d89fe 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -121,6 +121,42 @@ class MinLengthLogitsProcessor(LogitsProcessor): return scores +class MinNewTokensLengthLogitsProcessor(LogitsProcessor): + r""" + [`LogitsProcessor`] enforcing a min-length of new tokens by setting EOS (End-Of-Sequence) token probability to 0. + + Args: + prompt_length_to_skip (`int`): + The input tokens length. + min_new_tokens (`int`): + The minimum *new* tokens length below which the score of `eos_token_id` is set to `-float("Inf")`. + eos_token_id (`int`): + The id of the *end-of-sequence* token. + """ + + def __init__(self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id: int): + + for arg_name, arg_value in [ + ("prompt_length_to_skip", prompt_length_to_skip), + ("min_new_tokens", min_new_tokens), + ("eos_token_id", eos_token_id), + ]: + if not isinstance(arg_value, int) or arg_value < 0: + raise ValueError(f"`{arg_name}` has to be a positive integer, but is {arg_value}") + + self.prompt_length_to_skip = prompt_length_to_skip + self.min_new_tokens = min_new_tokens + self.eos_token_id = eos_token_id + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + + new_tokens_length = input_ids.shape[-1] - self.prompt_length_to_skip + if new_tokens_length < self.min_new_tokens: + scores[:, self.eos_token_id] = -float("inf") + + return scores + + class TemperatureLogitsWarper(LogitsWarper): r""" [`LogitsWarper`] for temperature (exponential scaling output probability distribution). diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 178a0b5ae6..fe46e68934 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -199,6 +199,13 @@ class MinLengthLogitsProcessor(metaclass=DummyObject): requires_backends(self, ["torch"]) +class MinNewTokensLengthLogitsProcessor(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class NoBadWordsLogitsProcessor(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/generation/test_logits_process.py b/tests/generation/test_logits_process.py index ec4a9f586e..5a47884f4a 100644 --- a/tests/generation/test_logits_process.py +++ b/tests/generation/test_logits_process.py @@ -36,6 +36,7 @@ if is_torch_available(): LogitNormalization, LogitsProcessorList, MinLengthLogitsProcessor, + MinNewTokensLengthLogitsProcessor, NoBadWordsLogitsProcessor, NoRepeatNGramLogitsProcessor, PrefixConstrainedLogitsProcessor, @@ -72,6 +73,54 @@ class LogitsProcessorTest(unittest.TestCase): scores_before_min_length = min_dist_processor(input_ids, scores) self.assertFalse(torch.isinf(scores_before_min_length).any()) + def test_new_min_length_dist_processor(self): + vocab_size = 20 + batch_size = 4 + eos_token_id = 0 + + # check that first input is skipped (min new length applying) + input_ids = ids_tensor((batch_size, 5), vocab_size=20) + new_min_dist_processor = MinNewTokensLengthLogitsProcessor( + prompt_length_to_skip=input_ids.shape[-1], min_new_tokens=3, eos_token_id=eos_token_id + ) + + scores = self._get_uniform_logits(batch_size, vocab_size) + scores_before_min_length = new_min_dist_processor(input_ids, scores) + self.assertListEqual(scores_before_min_length[:, eos_token_id].tolist(), batch_size * [-float("inf")]) + + # check that, for skipping, now prompt length is 5, after that we expect first 5 tokens will be skipped + self.assertTrue(new_min_dist_processor.prompt_length_to_skip == 5) + + # check that min length is applied at length 2 + input_ids = ids_tensor((batch_size, 2), vocab_size=20) + scores = self._get_uniform_logits(batch_size, vocab_size) + scores_before_min_length = new_min_dist_processor(input_ids, scores) + self.assertListEqual(scores_before_min_length[:, eos_token_id].tolist(), batch_size * [-float("inf")]) + + # check that min new length is applied at length 6 (because it has only 1 new token) + input_ids = ids_tensor((batch_size, 6), vocab_size=20) + scores = self._get_uniform_logits(batch_size, vocab_size) + scores_before_min_length = new_min_dist_processor(input_ids, scores) + self.assertListEqual(scores_before_min_length[:, eos_token_id].tolist(), batch_size * [-float("inf")]) + + # check that min new length is applied at length 7 (because it has only 2 new tokens) + input_ids = ids_tensor((batch_size, 7), vocab_size=20) + scores = self._get_uniform_logits(batch_size, vocab_size) + scores_before_min_length = new_min_dist_processor(input_ids, scores) + self.assertListEqual(scores_before_min_length[:, eos_token_id].tolist(), batch_size * [-float("inf")]) + + # check that min new length is not applied anymore at length 8 + input_ids = ids_tensor((batch_size, 8), vocab_size=20) + scores = self._get_uniform_logits(batch_size, vocab_size) + scores_before_min_length = new_min_dist_processor(input_ids, scores) + self.assertFalse(torch.isinf(scores_before_min_length).any()) + + # check that min new length is not applied anymore at length 15 + input_ids = ids_tensor((batch_size, 15), vocab_size=20) + scores = self._get_uniform_logits(batch_size, vocab_size) + scores_before_min_length = new_min_dist_processor(input_ids, scores) + self.assertFalse(torch.isinf(scores_before_min_length).any()) + def test_temperature_dist_warper(self): input_ids = None length = 20