diff --git a/src/transformers/generation/continuous_batching.py b/src/transformers/generation/continuous_batching.py index e462e483c2..9fc57b11d7 100644 --- a/src/transformers/generation/continuous_batching.py +++ b/src/transformers/generation/continuous_batching.py @@ -1272,6 +1272,11 @@ class ContinuousBatchingManager: @traced(span_name="logit_processing") def _process_logit(self, batch_data, logits): + # Pass continuous batching context to logits processor if it supports it. TODO we should find a way to make this a little bit cleaner! + if hasattr(self.logit_processor, "set_continuous_batching_context"): + self.logit_processor.set_continuous_batching_context( + batch_data["logits_indices"], batch_data["cumulative_seqlens_q"] + ) return self.logit_processor(batch_data["input_ids"], logits) @traced(span_name="sampling") diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index d4c08e270b..d001fe1a94 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -355,17 +355,54 @@ class RepetitionPenaltyLogitsProcessor(LogitsProcessor): self.penalty = penalty self.prompt_ignore_length = prompt_ignore_length + self.logits_indices = None + self.cumulative_seqlens_q = None + + def set_continuous_batching_context(self, logits_indices: torch.Tensor, cumulative_seqlens_q: torch.Tensor): + self.logits_indices = logits_indices + self.cumulative_seqlens_q = cumulative_seqlens_q @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: if self.prompt_ignore_length: input_ids = input_ids[:, self.prompt_ignore_length :] - score = torch.gather(scores, 1, input_ids) + if scores.dim() == 3: + if self.logits_indices is not None and self.cumulative_seqlens_q is not None: + batch_size, seq_len, vocab_size = scores.shape + last_positions = self.logits_indices + last_scores = scores[0, last_positions, :] + # Prepare token mask + token_mask = torch.zeros_like(last_scores, dtype=torch.bool) + cu_seq_lens = self.cumulative_seqlens_q + lengths = cu_seq_lens[1:] - cu_seq_lens[:-1] + seq_indices = torch.repeat_interleave(torch.arange(len(lengths), device=input_ids.device), lengths) + token_mask[seq_indices, input_ids] = True + + # Apply penalty + penalty_scores = torch.where(last_scores < 0, last_scores * self.penalty, last_scores / self.penalty) + scores[0, last_positions, :] = torch.where(token_mask, penalty_scores, last_scores) + else: + batch_size, seq_len, vocab_size = scores.shape + last_scores = scores[:, -1, :] + token_mask = torch.zeros_like(last_scores, dtype=torch.bool) + if input_ids.dim() == 1: + unique_tokens = torch.unique(input_ids) + token_mask.scatter_(1, unique_tokens.unsqueeze(0), True) + else: + token_mask.scatter_(1, input_ids, True) + # if last_scores < 0 then repetition penalty has to be multiplied to reduce the token probabilities + penalty_scores = torch.where(last_scores < 0, last_scores * self.penalty, last_scores / self.penalty) + scores[:, -1, :] = torch.where(token_mask, penalty_scores, last_scores) + return scores + + if input_ids.dim() == 1: + input_ids = input_ids.unsqueeze(1) + + score = torch.gather(scores, 1, input_ids) # if score < 0 then repetition penalty has to be multiplied to reduce the token probabilities score = torch.where(score < 0, score * self.penalty, score / self.penalty) - scores_processed = scores.scatter(1, input_ids, score) return scores_processed @@ -963,12 +1000,12 @@ class NoRepeatNGramLogitsProcessor(LogitsProcessor): >>> output = model.generate(**inputs) >>> print(tokenizer.decode(output[0], skip_special_tokens=True)) - Today I’m not sure if I’m going to be able to do it. + Today I'm not sure if I'm going to be able to do it. - >>> # Now let's add ngram size using `no_repeat_ngram_size`. This stops the repetitions ("I’m") in the output. + >>> # Now let's add ngram size using `no_repeat_ngram_size`. This stops the repetitions ("I'm") in the output. >>> output = model.generate(**inputs, no_repeat_ngram_size=2) >>> print(tokenizer.decode(output[0], skip_special_tokens=True)) - Today I’m not sure if I can get a better understanding of the nature of this issue + Today I'm not sure if I can get a better understanding of the nature of this issue ``` """ diff --git a/tests/generation/test_logits_process.py b/tests/generation/test_logits_process.py index 834c502b1a..df68f9c621 100644 --- a/tests/generation/test_logits_process.py +++ b/tests/generation/test_logits_process.py @@ -286,6 +286,39 @@ class LogitsProcessorTest(unittest.TestCase): # processor should not change logits in-place self.assertFalse(torch.all(scores == processed_scores)) + def test_repetition_penalty_continuous_batching(self): + vocab_size = 10 + + input_ids = torch.tensor([1, 2, 3, 4, 5, 6], device=torch_device, dtype=torch.long) + scores = torch.ones((1, 6, vocab_size), device=torch_device, dtype=torch.float) / vocab_size + + scores[0, 2, 1] = -2.0 + scores[0, 2, 2] = 3.0 + scores[0, 2, 3] = 4.0 + scores[0, 5, 4] = -5.0 + scores[0, 5, 5] = 6.0 + scores[0, 5, 6] = 7.0 + + logits_indices = torch.tensor([2, 5], device=torch_device, dtype=torch.long) + cumulative_seqlens_q = torch.tensor([0, 3, 6], device=torch_device, dtype=torch.long) + + rep_penalty_proc = RepetitionPenaltyLogitsProcessor(penalty=2.0) + rep_penalty_proc.set_continuous_batching_context(logits_indices, cumulative_seqlens_q) + + original_scores = scores.clone() + processed_scores = rep_penalty_proc(input_ids, scores) + + self.assertAlmostEqual(processed_scores[0, 2, 1].item(), -2.0 * 2.0) + self.assertAlmostEqual(processed_scores[0, 2, 2].item(), 3.0 / 2.0) + self.assertAlmostEqual(processed_scores[0, 2, 3].item(), 4.0 / 2.0) + self.assertAlmostEqual(processed_scores[0, 5, 4].item(), -5.0 * 2.0) + self.assertAlmostEqual(processed_scores[0, 5, 5].item(), 6.0 / 2.0) + self.assertAlmostEqual(processed_scores[0, 5, 6].item(), 7.0 / 2.0) + self.assertAlmostEqual(processed_scores[0, 2, 0].item(), 1.0 / vocab_size) + self.assertAlmostEqual(processed_scores[0, 5, 0].item(), 1.0 / vocab_size) + + self.assertFalse(torch.all(original_scores == processed_scores)) + def test_top_k_dist_warper(self): input_ids = None vocab_size = 10