From 29bdb88368e319acb0cbe145021427d37a3507c9 Mon Sep 17 00:00:00 2001 From: Binoy Dalal Date: Fri, 20 Nov 2020 13:59:06 -0500 Subject: [PATCH] Vectorize RepetitionPenaltyLogitsProcessor to improve performance (#8598) * refactored existing nested loops to vectorized implementation * replaced explicit indexing with torch.where * modifying score for previous input_ids only --- src/transformers/generation_logits_process.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/transformers/generation_logits_process.py b/src/transformers/generation_logits_process.py index dc6b183c4f..0a841e8955 100644 --- a/src/transformers/generation_logits_process.py +++ b/src/transformers/generation_logits_process.py @@ -146,13 +146,13 @@ class RepetitionPenaltyLogitsProcessor(LogitsProcessor): self.penalty = penalty def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: - for i in range(scores.shape[0]): - for previous_token in set(input_ids[i].tolist()): - # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability - if scores[i, previous_token] < 0: - scores[i, previous_token] *= self.penalty - else: - scores[i, previous_token] /= self.penalty + ranges = torch.arange(scores.shape[0]) + score = scores[ranges[:, None], input_ids] + + # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability + score = torch.where(score < 0, score * self.penalty, score / self.penalty) + + scores[ranges[:, None], input_ids] = score return scores