[Paged-Attention] Handle continuous batching for repetition penalty (#39457)

* Handle continuous batching for repetition penalty

* fix last scores and with token mask creation

* add test

* Update src/transformers/generation/continuous_batching.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/generation/logits_process.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fix formatting

* remove unneeded cast

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Kashif Rasul
2025-07-22 18:13:40 +02:00
committed by GitHub
parent cbcb8e6c1f
commit 2936902a76
3 changed files with 80 additions and 5 deletions

View File

@@ -286,6 +286,39 @@ class LogitsProcessorTest(unittest.TestCase):
# processor should not change logits in-place
self.assertFalse(torch.all(scores == processed_scores))
def test_repetition_penalty_continuous_batching(self):
vocab_size = 10
input_ids = torch.tensor([1, 2, 3, 4, 5, 6], device=torch_device, dtype=torch.long)
scores = torch.ones((1, 6, vocab_size), device=torch_device, dtype=torch.float) / vocab_size
scores[0, 2, 1] = -2.0
scores[0, 2, 2] = 3.0
scores[0, 2, 3] = 4.0
scores[0, 5, 4] = -5.0
scores[0, 5, 5] = 6.0
scores[0, 5, 6] = 7.0
logits_indices = torch.tensor([2, 5], device=torch_device, dtype=torch.long)
cumulative_seqlens_q = torch.tensor([0, 3, 6], device=torch_device, dtype=torch.long)
rep_penalty_proc = RepetitionPenaltyLogitsProcessor(penalty=2.0)
rep_penalty_proc.set_continuous_batching_context(logits_indices, cumulative_seqlens_q)
original_scores = scores.clone()
processed_scores = rep_penalty_proc(input_ids, scores)
self.assertAlmostEqual(processed_scores[0, 2, 1].item(), -2.0 * 2.0)
self.assertAlmostEqual(processed_scores[0, 2, 2].item(), 3.0 / 2.0)
self.assertAlmostEqual(processed_scores[0, 2, 3].item(), 4.0 / 2.0)
self.assertAlmostEqual(processed_scores[0, 5, 4].item(), -5.0 * 2.0)
self.assertAlmostEqual(processed_scores[0, 5, 5].item(), 6.0 / 2.0)
self.assertAlmostEqual(processed_scores[0, 5, 6].item(), 7.0 / 2.0)
self.assertAlmostEqual(processed_scores[0, 2, 0].item(), 1.0 / vocab_size)
self.assertAlmostEqual(processed_scores[0, 5, 0].item(), 1.0 / vocab_size)
self.assertFalse(torch.all(original_scores == processed_scores))
def test_top_k_dist_warper(self):
input_ids = None
vocab_size = 10