[Paged-Attention] Handle continuous batching for repetition penalty (#39457)
* Handle continuous batching for repetition penalty
* fix last scores and with token mask creation
* add test
* Update src/transformers/generation/continuous_batching.py

  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Update src/transformers/generation/logits_process.py

  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* fix formatting
* remove unneeded cast

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -1272,6 +1272,11 @@ class ContinuousBatchingManager:
|
|||||||
|
|
||||||
@traced(span_name="logit_processing")
|
@traced(span_name="logit_processing")
|
||||||
def _process_logit(self, batch_data, logits):
|
def _process_logit(self, batch_data, logits):
|
||||||
|
# Pass continuous batching context to logits processor if it supports it. TODO we should find a way to make this a little bit cleaner!
|
||||||
|
if hasattr(self.logit_processor, "set_continuous_batching_context"):
|
||||||
|
self.logit_processor.set_continuous_batching_context(
|
||||||
|
batch_data["logits_indices"], batch_data["cumulative_seqlens_q"]
|
||||||
|
)
|
||||||
return self.logit_processor(batch_data["input_ids"], logits)
|
return self.logit_processor(batch_data["input_ids"], logits)
|
||||||
|
|
||||||
@traced(span_name="sampling")
|
@traced(span_name="sampling")
|
||||||
|
|||||||
@@ -355,17 +355,54 @@ class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
|
|||||||
|
|
||||||
self.penalty = penalty
|
self.penalty = penalty
|
||||||
self.prompt_ignore_length = prompt_ignore_length
|
self.prompt_ignore_length = prompt_ignore_length
|
||||||
|
self.logits_indices = None
|
||||||
|
self.cumulative_seqlens_q = None
|
||||||
|
|
||||||
|
def set_continuous_batching_context(self, logits_indices: torch.Tensor, cumulative_seqlens_q: torch.Tensor):
|
||||||
|
self.logits_indices = logits_indices
|
||||||
|
self.cumulative_seqlens_q = cumulative_seqlens_q
|
||||||
|
|
||||||
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
if self.prompt_ignore_length:
|
if self.prompt_ignore_length:
|
||||||
input_ids = input_ids[:, self.prompt_ignore_length :]
|
input_ids = input_ids[:, self.prompt_ignore_length :]
|
||||||
|
|
||||||
score = torch.gather(scores, 1, input_ids)
|
if scores.dim() == 3:
|
||||||
|
if self.logits_indices is not None and self.cumulative_seqlens_q is not None:
|
||||||
|
batch_size, seq_len, vocab_size = scores.shape
|
||||||
|
last_positions = self.logits_indices
|
||||||
|
last_scores = scores[0, last_positions, :]
|
||||||
|
|
||||||
|
# Prepare token mask
|
||||||
|
token_mask = torch.zeros_like(last_scores, dtype=torch.bool)
|
||||||
|
cu_seq_lens = self.cumulative_seqlens_q
|
||||||
|
lengths = cu_seq_lens[1:] - cu_seq_lens[:-1]
|
||||||
|
seq_indices = torch.repeat_interleave(torch.arange(len(lengths), device=input_ids.device), lengths)
|
||||||
|
token_mask[seq_indices, input_ids] = True
|
||||||
|
|
||||||
|
# Apply penalty
|
||||||
|
penalty_scores = torch.where(last_scores < 0, last_scores * self.penalty, last_scores / self.penalty)
|
||||||
|
scores[0, last_positions, :] = torch.where(token_mask, penalty_scores, last_scores)
|
||||||
|
else:
|
||||||
|
batch_size, seq_len, vocab_size = scores.shape
|
||||||
|
last_scores = scores[:, -1, :]
|
||||||
|
token_mask = torch.zeros_like(last_scores, dtype=torch.bool)
|
||||||
|
if input_ids.dim() == 1:
|
||||||
|
unique_tokens = torch.unique(input_ids)
|
||||||
|
token_mask.scatter_(1, unique_tokens.unsqueeze(0), True)
|
||||||
|
else:
|
||||||
|
token_mask.scatter_(1, input_ids, True)
|
||||||
|
# if last_scores < 0 then repetition penalty has to be multiplied to reduce the token probabilities
|
||||||
|
penalty_scores = torch.where(last_scores < 0, last_scores * self.penalty, last_scores / self.penalty)
|
||||||
|
scores[:, -1, :] = torch.where(token_mask, penalty_scores, last_scores)
|
||||||
|
return scores
|
||||||
|
|
||||||
|
if input_ids.dim() == 1:
|
||||||
|
input_ids = input_ids.unsqueeze(1)
|
||||||
|
|
||||||
|
score = torch.gather(scores, 1, input_ids)
|
||||||
# if score < 0 then repetition penalty has to be multiplied to reduce the token probabilities
|
# if score < 0 then repetition penalty has to be multiplied to reduce the token probabilities
|
||||||
score = torch.where(score < 0, score * self.penalty, score / self.penalty)
|
score = torch.where(score < 0, score * self.penalty, score / self.penalty)
|
||||||
|
|
||||||
scores_processed = scores.scatter(1, input_ids, score)
|
scores_processed = scores.scatter(1, input_ids, score)
|
||||||
return scores_processed
|
return scores_processed
|
||||||
|
|
||||||
@@ -963,12 +1000,12 @@ class NoRepeatNGramLogitsProcessor(LogitsProcessor):
|
|||||||
|
|
||||||
>>> output = model.generate(**inputs)
|
>>> output = model.generate(**inputs)
|
||||||
>>> print(tokenizer.decode(output[0], skip_special_tokens=True))
|
>>> print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||||
Today I’m not sure if I’m going to be able to do it.
|
Today I'm not sure if I'm going to be able to do it.
|
||||||
|
|
||||||
>>> # Now let's add ngram size using `no_repeat_ngram_size`. This stops the repetitions ("I’m") in the output.
|
>>> # Now let's add ngram size using `no_repeat_ngram_size`. This stops the repetitions ("I'm") in the output.
|
||||||
>>> output = model.generate(**inputs, no_repeat_ngram_size=2)
|
>>> output = model.generate(**inputs, no_repeat_ngram_size=2)
|
||||||
>>> print(tokenizer.decode(output[0], skip_special_tokens=True))
|
>>> print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||||
Today I’m not sure if I can get a better understanding of the nature of this issue
|
Today I'm not sure if I can get a better understanding of the nature of this issue
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|||||||
@@ -286,6 +286,39 @@ class LogitsProcessorTest(unittest.TestCase):
|
|||||||
# processor should not change logits in-place
|
# processor should not change logits in-place
|
||||||
self.assertFalse(torch.all(scores == processed_scores))
|
self.assertFalse(torch.all(scores == processed_scores))
|
||||||
|
|
||||||
|
def test_repetition_penalty_continuous_batching(self):
|
||||||
|
vocab_size = 10
|
||||||
|
|
||||||
|
input_ids = torch.tensor([1, 2, 3, 4, 5, 6], device=torch_device, dtype=torch.long)
|
||||||
|
scores = torch.ones((1, 6, vocab_size), device=torch_device, dtype=torch.float) / vocab_size
|
||||||
|
|
||||||
|
scores[0, 2, 1] = -2.0
|
||||||
|
scores[0, 2, 2] = 3.0
|
||||||
|
scores[0, 2, 3] = 4.0
|
||||||
|
scores[0, 5, 4] = -5.0
|
||||||
|
scores[0, 5, 5] = 6.0
|
||||||
|
scores[0, 5, 6] = 7.0
|
||||||
|
|
||||||
|
logits_indices = torch.tensor([2, 5], device=torch_device, dtype=torch.long)
|
||||||
|
cumulative_seqlens_q = torch.tensor([0, 3, 6], device=torch_device, dtype=torch.long)
|
||||||
|
|
||||||
|
rep_penalty_proc = RepetitionPenaltyLogitsProcessor(penalty=2.0)
|
||||||
|
rep_penalty_proc.set_continuous_batching_context(logits_indices, cumulative_seqlens_q)
|
||||||
|
|
||||||
|
original_scores = scores.clone()
|
||||||
|
processed_scores = rep_penalty_proc(input_ids, scores)
|
||||||
|
|
||||||
|
self.assertAlmostEqual(processed_scores[0, 2, 1].item(), -2.0 * 2.0)
|
||||||
|
self.assertAlmostEqual(processed_scores[0, 2, 2].item(), 3.0 / 2.0)
|
||||||
|
self.assertAlmostEqual(processed_scores[0, 2, 3].item(), 4.0 / 2.0)
|
||||||
|
self.assertAlmostEqual(processed_scores[0, 5, 4].item(), -5.0 * 2.0)
|
||||||
|
self.assertAlmostEqual(processed_scores[0, 5, 5].item(), 6.0 / 2.0)
|
||||||
|
self.assertAlmostEqual(processed_scores[0, 5, 6].item(), 7.0 / 2.0)
|
||||||
|
self.assertAlmostEqual(processed_scores[0, 2, 0].item(), 1.0 / vocab_size)
|
||||||
|
self.assertAlmostEqual(processed_scores[0, 5, 0].item(), 1.0 / vocab_size)
|
||||||
|
|
||||||
|
self.assertFalse(torch.all(original_scores == processed_scores))
|
||||||
|
|
||||||
def test_top_k_dist_warper(self):
|
def test_top_k_dist_warper(self):
|
||||||
input_ids = None
|
input_ids = None
|
||||||
vocab_size = 10
|
vocab_size = 10
|
||||||
|
|||||||
Reference in New Issue
Block a user