[Paged-Attention] Handle continuous batching for repetition penalty (#39457)
* Handle continuous batching for repetition penalty
* Fix last scores and token mask creation
* Add test
* Update src/transformers/generation/continuous_batching.py
  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Update src/transformers/generation/logits_process.py
  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Fix formatting
* Remove unneeded cast

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -1272,6 +1272,11 @@ class ContinuousBatchingManager:
|
||||
|
||||
@traced(span_name="logit_processing")
def _process_logit(self, batch_data, logits):
    """Run the configured logits processor on `logits` for the current batch.

    If the processor exposes `set_continuous_batching_context`, it is first given
    `batch_data["logits_indices"]` and `batch_data["cumulative_seqlens_q"]`. The
    return value is whatever the processor returns for
    `(batch_data["input_ids"], logits)`.
    """
    processor = self.logit_processor

    # Pass continuous batching context to logits processor if it supports it.
    # TODO we should find a way to make this a little bit cleaner!
    if hasattr(processor, "set_continuous_batching_context"):
        processor.set_continuous_batching_context(
            batch_data["logits_indices"],
            batch_data["cumulative_seqlens_q"],
        )

    return processor(batch_data["input_ids"], logits)
|
||||
|
||||
@traced(span_name="sampling")
|
||||
|
||||
@@ -355,17 +355,54 @@ class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
|
||||
|
||||
self.penalty = penalty
|
||||
self.prompt_ignore_length = prompt_ignore_length
|
||||
self.logits_indices = None
|
||||
self.cumulative_seqlens_q = None
|
||||
|
||||
def set_continuous_batching_context(self, logits_indices: torch.Tensor, cumulative_seqlens_q: torch.Tensor):
    """Store the two continuous-batching tensors on the processor.

    Both values are kept as given (no copy, no validation) under
    `self.logits_indices` and `self.cumulative_seqlens_q`.
    """
    self.logits_indices, self.cumulative_seqlens_q = logits_indices, cumulative_seqlens_q
|
||||
|
||||
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
    # Apply the repetition penalty to the scores of every token id found in `input_ids`.
    #
    # (Documented with comments rather than a docstring so that the text the decorator
    # attaches to this method is left exactly as it was.)
    #
    # Three cases are handled:
    #   * 3-D `scores` with a continuous-batching context set (both `self.logits_indices`
    #     and `self.cumulative_seqlens_q` are not None): only the rows
    #     `scores[0, self.logits_indices, :]` are penalized.
    #   * 3-D `scores` without that context: only `scores[:, -1, :]` is penalized.
    #   * any other rank: the classic gather / scatter path along dim 1.
    #
    # A penalized score is multiplied by `self.penalty` when negative and divided by it
    # otherwise. The 3-D branches write into `scores` itself and return that same tensor;
    # the last branch returns the result of `scores.scatter(...)`.
    if self.prompt_ignore_length:
        input_ids = input_ids[:, self.prompt_ignore_length :]

    # NOTE: no `torch.gather` happens before the rank dispatch below. Gathering here would
    # raise whenever `input_ids` and `scores` differ in rank (e.g. 1-D ids with 3-D scores,
    # or 1-D ids that have not been unsqueezed yet), and its result was never used.
    if scores.dim() == 3:
        if self.logits_indices is not None and self.cumulative_seqlens_q is not None:
            last_positions = self.logits_indices
            last_scores = scores[0, last_positions, :]

            # Prepare token mask: entry i of `seq_indices` is the index of the sequence that
            # token i belongs to, derived from consecutive differences of the cumulative
            # lengths. Row r of the mask then flags the token ids of sequence r.
            token_mask = torch.zeros_like(last_scores, dtype=torch.bool)
            cu_seq_lens = self.cumulative_seqlens_q
            lengths = cu_seq_lens[1:] - cu_seq_lens[:-1]
            seq_indices = torch.repeat_interleave(torch.arange(len(lengths), device=input_ids.device), lengths)
            token_mask[seq_indices, input_ids] = True

            # Apply penalty only where the mask is set; other entries are written back unchanged.
            penalty_scores = torch.where(last_scores < 0, last_scores * self.penalty, last_scores / self.penalty)
            scores[0, last_positions, :] = torch.where(token_mask, penalty_scores, last_scores)
        else:
            last_scores = scores[:, -1, :]
            token_mask = torch.zeros_like(last_scores, dtype=torch.bool)
            if input_ids.dim() == 1:
                # 1-D ids: the unique ids are scattered through a single index row.
                unique_tokens = torch.unique(input_ids)
                token_mask.scatter_(1, unique_tokens.unsqueeze(0), True)
            else:
                token_mask.scatter_(1, input_ids, True)
            # if last_scores < 0 then repetition penalty has to be multiplied to reduce the token probabilities
            penalty_scores = torch.where(last_scores < 0, last_scores * self.penalty, last_scores / self.penalty)
            scores[:, -1, :] = torch.where(token_mask, penalty_scores, last_scores)
        return scores

    # `gather` / `scatter` need an index of the same rank as `scores`.
    if input_ids.dim() == 1:
        input_ids = input_ids.unsqueeze(1)

    score = torch.gather(scores, 1, input_ids)
    # if score < 0 then repetition penalty has to be multiplied to reduce the token probabilities
    score = torch.where(score < 0, score * self.penalty, score / self.penalty)

    scores_processed = scores.scatter(1, input_ids, score)
    return scores_processed
|
||||
|
||||
@@ -963,12 +1000,12 @@ class NoRepeatNGramLogitsProcessor(LogitsProcessor):
|
||||
|
||||
>>> output = model.generate(**inputs)
|
||||
>>> print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
Today I’m not sure if I’m going to be able to do it.
|
||||
Today I'm not sure if I'm going to be able to do it.
|
||||
|
||||
>>> # Now let's add ngram size using `no_repeat_ngram_size`. This stops the repetitions ("I’m") in the output.
|
||||
>>> # Now let's add ngram size using `no_repeat_ngram_size`. This stops the repetitions ("I'm") in the output.
|
||||
>>> output = model.generate(**inputs, no_repeat_ngram_size=2)
|
||||
>>> print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
Today I’m not sure if I can get a better understanding of the nature of this issue
|
||||
Today I'm not sure if I can get a better understanding of the nature of this issue
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
@@ -286,6 +286,39 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
# processor should not change logits in-place
|
||||
self.assertFalse(torch.all(scores == processed_scores))
|
||||
|
||||
def test_repetition_penalty_continuous_batching(self):
    # Two sequences of three tokens each are packed into one row of six positions.
    # The processor is told to penalize positions 2 and 5 (`logits_indices`), and the
    # sequence boundaries are given by `cumulative_seqlens_q`.
    vocab_size = 10

    input_ids = torch.tensor([1, 2, 3, 4, 5, 6], device=torch_device, dtype=torch.long)
    scores = torch.ones((1, 6, vocab_size), device=torch_device, dtype=torch.float) / vocab_size

    # (position, token id) -> score planted before processing
    planted = {
        (2, 1): -2.0,
        (2, 2): 3.0,
        (2, 3): 4.0,
        (5, 4): -5.0,
        (5, 5): 6.0,
        (5, 6): 7.0,
    }
    for (position, token), value in planted.items():
        scores[0, position, token] = value

    logits_indices = torch.tensor([2, 5], device=torch_device, dtype=torch.long)
    cumulative_seqlens_q = torch.tensor([0, 3, 6], device=torch_device, dtype=torch.long)

    rep_penalty_proc = RepetitionPenaltyLogitsProcessor(penalty=2.0)
    rep_penalty_proc.set_continuous_batching_context(logits_indices, cumulative_seqlens_q)

    original_scores = scores.clone()
    processed_scores = rep_penalty_proc(input_ids, scores)

    # Negative scores are multiplied by the penalty, positive ones divided by it;
    # token 0 appears in neither sequence and must keep its uniform value.
    expectations = (
        (2, 1, -2.0 * 2.0),
        (2, 2, 3.0 / 2.0),
        (2, 3, 4.0 / 2.0),
        (5, 4, -5.0 * 2.0),
        (5, 5, 6.0 / 2.0),
        (5, 6, 7.0 / 2.0),
        (2, 0, 1.0 / vocab_size),
        (5, 0, 1.0 / vocab_size),
    )
    for position, token, expected in expectations:
        self.assertAlmostEqual(processed_scores[0, position, token].item(), expected)

    self.assertFalse(torch.all(original_scores == processed_scores))
|
||||
|
||||
def test_top_k_dist_warper(self):
|
||||
input_ids = None
|
||||
vocab_size = 10
|
||||
|
||||
Reference in New Issue
Block a user