make RepetitionPenaltyLogitsProcessor faster (#9600)
This commit is contained in:
@@ -155,13 +155,12 @@ class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
|
||||
self.penalty = penalty
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
ranges = torch.arange(scores.shape[0])
|
||||
score = scores[ranges[:, None], input_ids]
|
||||
score = torch.gather(scores, 1, input_ids)
|
||||
|
||||
# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
|
||||
score = torch.where(score < 0, score * self.penalty, score / self.penalty)
|
||||
|
||||
scores[ranges[:, None], input_ids] = score
|
||||
scores.scatter_(1, input_ids, score)
|
||||
return scores
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user