Allow Exclusion of Input IDs from RepetitionPenaltyLogitsProcessor (#37625)
* Allow exclusion of input IDs for repetition penalty
* Add logit proc tests for rep penalty exclusion
* Expose rep pen flag through generate
* Only slice if needed
* keep current rep pen default behavior
* Revert exposing reppen changes through generate
* Fix test arg
* Update src/transformers/generation/logits_process.py
  Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
* Rename to rep penalty kwarg
* Add custom repetition penalty processor example
* Validate prompt_ignore_length
---------
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
@@ -203,6 +203,56 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
# processor should not change logits in-place
|
||||
self.assertFalse(torch.all(scores == processed_scores))
|
||||
|
||||
def test_repetition_penalty_dist_process_exclusion_no_new_input_ids(self):
    """Check that no score is penalized when every input ID is ignored.

    The processor is built with ``prompt_ignore_length`` equal to the full
    length of the sequence it is later called with, so the processed scores
    are expected to equal the unprocessed ones.
    """
    vocab_size = 10
    prompt_ids = torch.tensor([[0, 1], [5, 0]], device=torch_device, dtype=torch.long)
    prompt_length = prompt_ids.shape[-1]

    logits = self._get_uniform_logits(batch_size=2, length=vocab_size)

    # Overwrite two entries (one negative, one positive) so they stand out
    # from the rest of the logits.
    logits[0, 0] = -(1 / vocab_size)
    logits[1, 5] = 4 / vocab_size

    processor = RepetitionPenaltyLogitsProcessor(penalty=2.0, prompt_ignore_length=prompt_length)
    penalized = processor(prompt_ids, logits)

    # The IDs passed at call time are exactly the ones covered by
    # prompt_ignore_length, which should be the same as calling with no
    # input IDs at all: nothing should be penalized.
    self.assertTrue(torch.all(logits == penalized))
|
||||
|
||||
def test_repetition_penalty_dist_process_exclusion_with_new_input_ids(self):
    """Check that IDs beyond ``prompt_ignore_length`` are still penalized.

    The processor is configured with the length of the original two-token
    prompt and then called with a four-token sequence. The assertions expect
    negative scores to be multiplied by the penalty and positive scores to
    be divided by it for the token IDs listed below.
    """
    vocab_size = 10
    orig_input_ids = torch.tensor([[0, 1], [5, 0]], device=torch_device, dtype=torch.long)
    curr_input_ids = torch.tensor([[0, 1, 0, 1], [5, 0, 5, 0]], device=torch_device, dtype=torch.long)

    logits = self._get_uniform_logits(batch_size=2, length=vocab_size)

    # Overwrite two entries (one negative, one positive) so they stand out
    # from the rest of the logits.
    logits[0, 0] = -(1 / vocab_size)
    logits[1, 5] = 4 / vocab_size

    processor = RepetitionPenaltyLogitsProcessor(
        penalty=2.0,
        prompt_ignore_length=orig_input_ids.shape[-1],
    )
    penalized = processor(curr_input_ids, logits)

    # (batch row, token id) -> score expected after the penalty is applied.
    expected_scores = {
        (0, 0): -(1 / vocab_size) * 2,
        (0, 1): (1 / vocab_size) / 2,
        (1, 0): (1 / vocab_size) / 2,
        (1, 5): (4 / vocab_size) / 2,
    }
    for (row, token_id), expected in expected_scores.items():
        self.assertAlmostEqual(penalized[row, token_id].item(), expected)

    # The processor should return new scores rather than edit them in place.
    self.assertFalse(torch.all(logits == penalized))
|
||||
|
||||
def test_encoder_repetition_penalty_dist_process(self):
|
||||
input_ids = torch.tensor([[0, 1], [5, 0]], device=torch_device, dtype=torch.long)
|
||||
vocab_size = 10
|
||||
|
||||
Reference in New Issue
Block a user