MinNewTokensLengthLogitsProcessor for .generate method #20814 (#20892)

* feat: add min new length logit processor

* test: add min new length logit processor

* docs: add MinNewTokensLengthLogitsProcessor

* feat: import MinNewTokensLengthLogitsProcessor

* fix: update pytorch dummy objects

* refactor & fix: rename attributes and var and get rid of dynamic attribute

* tests: align test with new interface

* docs: fix typo

* docs: minor clarification

* Empty-Commit

* empty commit

* run automated quality edits

Co-authored-by: Joao Gante <joao@huggingface.co>
This commit is contained in:
Konstantin Kotik
2023-01-03 14:29:02 +03:00
committed by GitHub
parent 4fd89e4978
commit 367fdf3330
6 changed files with 99 additions and 0 deletions

View File

@@ -36,6 +36,7 @@ if is_torch_available():
LogitNormalization,
LogitsProcessorList,
MinLengthLogitsProcessor,
MinNewTokensLengthLogitsProcessor,
NoBadWordsLogitsProcessor,
NoRepeatNGramLogitsProcessor,
PrefixConstrainedLogitsProcessor,
@@ -72,6 +73,54 @@ class LogitsProcessorTest(unittest.TestCase):
scores_before_min_length = min_dist_processor(input_ids, scores)
self.assertFalse(torch.isinf(scores_before_min_length).any())
def test_new_min_length_dist_processor(self):
vocab_size = 20
batch_size = 4
eos_token_id = 0
# check that first input is skipped (min new length applying)
input_ids = ids_tensor((batch_size, 5), vocab_size=20)
new_min_dist_processor = MinNewTokensLengthLogitsProcessor(
prompt_length_to_skip=input_ids.shape[-1], min_new_tokens=3, eos_token_id=eos_token_id
)
scores = self._get_uniform_logits(batch_size, vocab_size)
scores_before_min_length = new_min_dist_processor(input_ids, scores)
self.assertListEqual(scores_before_min_length[:, eos_token_id].tolist(), batch_size * [-float("inf")])
# check that, for skipping, now prompt length is 5, after that we expect first 5 tokens will be skipped
self.assertTrue(new_min_dist_processor.prompt_length_to_skip == 5)
# check that min length is applied at length 2
input_ids = ids_tensor((batch_size, 2), vocab_size=20)
scores = self._get_uniform_logits(batch_size, vocab_size)
scores_before_min_length = new_min_dist_processor(input_ids, scores)
self.assertListEqual(scores_before_min_length[:, eos_token_id].tolist(), batch_size * [-float("inf")])
# check that min new length is applied at length 6 (because it has only 1 new token)
input_ids = ids_tensor((batch_size, 6), vocab_size=20)
scores = self._get_uniform_logits(batch_size, vocab_size)
scores_before_min_length = new_min_dist_processor(input_ids, scores)
self.assertListEqual(scores_before_min_length[:, eos_token_id].tolist(), batch_size * [-float("inf")])
# check that min new length is applied at length 7 (because it has only 2 new tokens)
input_ids = ids_tensor((batch_size, 7), vocab_size=20)
scores = self._get_uniform_logits(batch_size, vocab_size)
scores_before_min_length = new_min_dist_processor(input_ids, scores)
self.assertListEqual(scores_before_min_length[:, eos_token_id].tolist(), batch_size * [-float("inf")])
# check that min new length is not applied anymore at length 8
input_ids = ids_tensor((batch_size, 8), vocab_size=20)
scores = self._get_uniform_logits(batch_size, vocab_size)
scores_before_min_length = new_min_dist_processor(input_ids, scores)
self.assertFalse(torch.isinf(scores_before_min_length).any())
# check that min new length is not applied anymore at length 15
input_ids = ids_tensor((batch_size, 15), vocab_size=20)
scores = self._get_uniform_logits(batch_size, vocab_size)
scores_before_min_length = new_min_dist_processor(input_ids, scores)
self.assertFalse(torch.isinf(scores_before_min_length).any())
def test_temperature_dist_warper(self):
input_ids = None
length = 20