Vectorize RepetitionPenaltyLogitsProcessor to improve performance (#8598)

* refactored exisiting nested loops to vectorized implementation

* replaced explicit indexing with torch.where

* modifying score for previous input_ids only
This commit is contained in:
Binoy Dalal
2020-11-20 13:59:06 -05:00
committed by GitHub
parent 2594bd8b73
commit 29bdb88368

View File

@@ -146,13 +146,13 @@ class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
self.penalty = penalty self.penalty = penalty
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
for i in range(scores.shape[0]): ranges = torch.arange(scores.shape[0])
for previous_token in set(input_ids[i].tolist()): score = scores[ranges[:, None], input_ids]
# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
if scores[i, previous_token] < 0: # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
scores[i, previous_token] *= self.penalty score = torch.where(score < 0, score * self.penalty, score / self.penalty)
else:
scores[i, previous_token] /= self.penalty scores[ranges[:, None], input_ids] = score
return scores return scores