[Paged-Attention] Handle continuous batching for repetition penalty (#39457)
* Handle continuous batching for repetition penalty
* fix last scores and with token mask creation
* add test
* Update src/transformers/generation/continuous_batching.py

  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Update src/transformers/generation/logits_process.py

  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* fix formatting
* remove unneeded cast

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -286,6 +286,39 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
# processor should not change logits in-place
|
||||
self.assertFalse(torch.all(scores == processed_scores))
|
||||
|
||||
def test_repetition_penalty_continuous_batching(self):
    """Check `RepetitionPenaltyLogitsProcessor` after a continuous-batching context is set.

    Six token ids are passed as one flat tensor, together with cumulative
    query lengths `[0, 3, 6]` and logits indices `[2, 5]`. The assertions
    expect tokens 1-3 to be penalized at position 2 and tokens 4-6 at
    position 5, while token 0 keeps its initial value at both positions.
    """
    vocab_size = 10
    penalty = 2.0

    # (position, token) -> raw logit written into `scores` before processing.
    seeded_logits = {
        (2, 1): -2.0,
        (2, 2): 3.0,
        (2, 3): 4.0,
        (5, 4): -5.0,
        (5, 5): 6.0,
        (5, 6): 7.0,
    }

    packed_ids = torch.tensor([1, 2, 3, 4, 5, 6], device=torch_device, dtype=torch.long)
    scores = torch.ones((1, 6, vocab_size), device=torch_device, dtype=torch.float) / vocab_size
    for (position, token), raw_logit in seeded_logits.items():
        scores[0, position, token] = raw_logit

    # Snapshot taken before the call, used by the final comparison.
    original_scores = scores.clone()

    logits_indices = torch.tensor([2, 5], device=torch_device, dtype=torch.long)
    cumulative_seqlens_q = torch.tensor([0, 3, 6], device=torch_device, dtype=torch.long)

    rep_penalty_proc = RepetitionPenaltyLogitsProcessor(penalty=penalty)
    rep_penalty_proc.set_continuous_batching_context(logits_indices, cumulative_seqlens_q)
    processed_scores = rep_penalty_proc(packed_ids, scores)

    # Expected rule: negative logits are multiplied by the penalty, positive ones divided by it.
    for (position, token), raw_logit in seeded_logits.items():
        if raw_logit < 0:
            expected = raw_logit * penalty
        else:
            expected = raw_logit / penalty
        self.assertAlmostEqual(processed_scores[0, position, token].item(), expected)

    # Token 0 is not among the input ids, so it should keep its initial value.
    for position in (2, 5):
        self.assertAlmostEqual(processed_scores[0, position, 0].item(), 1.0 / vocab_size)

    # The output must differ from the pre-call snapshot.
    self.assertFalse(torch.all(original_scores == processed_scores))
|
||||
|
||||
def test_top_k_dist_warper(self):
|
||||
input_ids = None
|
||||
vocab_size = 10
|
||||
|
||||
Reference in New Issue
Block a user