From a98173cc45727f619f4c99b60448584385caf30a Mon Sep 17 00:00:00 2001 From: LSinev Date: Wed, 20 Jan 2021 12:23:01 +0300 Subject: [PATCH] make RepetitionPenaltyLogitsProcessor faster (#9600) --- src/transformers/generation_logits_process.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/transformers/generation_logits_process.py b/src/transformers/generation_logits_process.py index 166cd5aa8c..a027eacbde 100644 --- a/src/transformers/generation_logits_process.py +++ b/src/transformers/generation_logits_process.py @@ -155,13 +155,12 @@ class RepetitionPenaltyLogitsProcessor(LogitsProcessor): self.penalty = penalty def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: - ranges = torch.arange(scores.shape[0]) - score = scores[ranges[:, None], input_ids] + score = torch.gather(scores, 1, input_ids) # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability score = torch.where(score < 0, score * self.penalty, score / self.penalty) - scores[ranges[:, None], input_ids] = score + scores.scatter_(1, input_ids, score) return scores