[Paged-Attention] Handle continuous batching for repetition penalty (#39457)

* Handle continuous batching for repetition penalty

* fix last scores and with token mask creation

* add test

* Update src/transformers/generation/continuous_batching.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/generation/logits_process.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fix formatting

* remove unneeded cast

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Kashif Rasul
2025-07-22 18:13:40 +02:00
committed by GitHub
parent cbcb8e6c1f
commit 2936902a76
3 changed files with 80 additions and 5 deletions

View File

@@ -1272,6 +1272,11 @@ class ContinuousBatchingManager:
@traced(span_name="logit_processing")
def _process_logit(self, batch_data, logits):
# Pass continuous batching context to logits processor if it supports it. TODO we should find a way to make this a little bit cleaner!
if hasattr(self.logit_processor, "set_continuous_batching_context"):
self.logit_processor.set_continuous_batching_context(
batch_data["logits_indices"], batch_data["cumulative_seqlens_q"]
)
return self.logit_processor(batch_data["input_ids"], logits)
@traced(span_name="sampling")

View File

@@ -355,17 +355,54 @@ class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
self.penalty = penalty
self.prompt_ignore_length = prompt_ignore_length
self.logits_indices = None
self.cumulative_seqlens_q = None
def set_continuous_batching_context(self, logits_indices: torch.Tensor, cumulative_seqlens_q: torch.Tensor):
self.logits_indices = logits_indices
self.cumulative_seqlens_q = cumulative_seqlens_q
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
if self.prompt_ignore_length:
input_ids = input_ids[:, self.prompt_ignore_length :]
score = torch.gather(scores, 1, input_ids)
if scores.dim() == 3:
if self.logits_indices is not None and self.cumulative_seqlens_q is not None:
batch_size, seq_len, vocab_size = scores.shape
last_positions = self.logits_indices
last_scores = scores[0, last_positions, :]
# Prepare token mask
token_mask = torch.zeros_like(last_scores, dtype=torch.bool)
cu_seq_lens = self.cumulative_seqlens_q
lengths = cu_seq_lens[1:] - cu_seq_lens[:-1]
seq_indices = torch.repeat_interleave(torch.arange(len(lengths), device=input_ids.device), lengths)
token_mask[seq_indices, input_ids] = True
# Apply penalty
penalty_scores = torch.where(last_scores < 0, last_scores * self.penalty, last_scores / self.penalty)
scores[0, last_positions, :] = torch.where(token_mask, penalty_scores, last_scores)
else:
batch_size, seq_len, vocab_size = scores.shape
last_scores = scores[:, -1, :]
token_mask = torch.zeros_like(last_scores, dtype=torch.bool)
if input_ids.dim() == 1:
unique_tokens = torch.unique(input_ids)
token_mask.scatter_(1, unique_tokens.unsqueeze(0), True)
else:
token_mask.scatter_(1, input_ids, True)
# if last_scores < 0 then repetition penalty has to be multiplied to reduce the token probabilities
penalty_scores = torch.where(last_scores < 0, last_scores * self.penalty, last_scores / self.penalty)
scores[:, -1, :] = torch.where(token_mask, penalty_scores, last_scores)
return scores
if input_ids.dim() == 1:
input_ids = input_ids.unsqueeze(1)
score = torch.gather(scores, 1, input_ids)
# if score < 0 then repetition penalty has to be multiplied to reduce the token probabilities
score = torch.where(score < 0, score * self.penalty, score / self.penalty)
scores_processed = scores.scatter(1, input_ids, score)
return scores_processed
@@ -963,12 +1000,12 @@ class NoRepeatNGramLogitsProcessor(LogitsProcessor):
>>> output = model.generate(**inputs)
>>> print(tokenizer.decode(output[0], skip_special_tokens=True))
Today Im not sure if Im going to be able to do it.
Today I'm not sure if I'm going to be able to do it.
>>> # Now let's add ngram size using `no_repeat_ngram_size`. This stops the repetitions ("Im") in the output.
>>> # Now let's add ngram size using `no_repeat_ngram_size`. This stops the repetitions ("I'm") in the output.
>>> output = model.generate(**inputs, no_repeat_ngram_size=2)
>>> print(tokenizer.decode(output[0], skip_special_tokens=True))
Today Im not sure if I can get a better understanding of the nature of this issue
Today I'm not sure if I can get a better understanding of the nature of this issue
```
"""

View File

@@ -286,6 +286,39 @@ class LogitsProcessorTest(unittest.TestCase):
# processor should not change logits in-place
self.assertFalse(torch.all(scores == processed_scores))
def test_repetition_penalty_continuous_batching(self):
vocab_size = 10
input_ids = torch.tensor([1, 2, 3, 4, 5, 6], device=torch_device, dtype=torch.long)
scores = torch.ones((1, 6, vocab_size), device=torch_device, dtype=torch.float) / vocab_size
scores[0, 2, 1] = -2.0
scores[0, 2, 2] = 3.0
scores[0, 2, 3] = 4.0
scores[0, 5, 4] = -5.0
scores[0, 5, 5] = 6.0
scores[0, 5, 6] = 7.0
logits_indices = torch.tensor([2, 5], device=torch_device, dtype=torch.long)
cumulative_seqlens_q = torch.tensor([0, 3, 6], device=torch_device, dtype=torch.long)
rep_penalty_proc = RepetitionPenaltyLogitsProcessor(penalty=2.0)
rep_penalty_proc.set_continuous_batching_context(logits_indices, cumulative_seqlens_q)
original_scores = scores.clone()
processed_scores = rep_penalty_proc(input_ids, scores)
self.assertAlmostEqual(processed_scores[0, 2, 1].item(), -2.0 * 2.0)
self.assertAlmostEqual(processed_scores[0, 2, 2].item(), 3.0 / 2.0)
self.assertAlmostEqual(processed_scores[0, 2, 3].item(), 4.0 / 2.0)
self.assertAlmostEqual(processed_scores[0, 5, 4].item(), -5.0 * 2.0)
self.assertAlmostEqual(processed_scores[0, 5, 5].item(), 6.0 / 2.0)
self.assertAlmostEqual(processed_scores[0, 5, 6].item(), 7.0 / 2.0)
self.assertAlmostEqual(processed_scores[0, 2, 0].item(), 1.0 / vocab_size)
self.assertAlmostEqual(processed_scores[0, 5, 0].item(), 1.0 / vocab_size)
self.assertFalse(torch.all(original_scores == processed_scores))
def test_top_k_dist_warper(self):
input_ids = None
vocab_size = 10