Vectorize RepetitionPenaltyLogitsProcessor to improve performance (#8598)
* refactored exisiting nested loops to vectorized implementation * replaced explicit indexing with torch.where * modifying score for previous input_ids only
This commit is contained in:
@@ -146,13 +146,13 @@ class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
|
|||||||
self.penalty = penalty
|
self.penalty = penalty
|
||||||
|
|
||||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
for i in range(scores.shape[0]):
|
ranges = torch.arange(scores.shape[0])
|
||||||
for previous_token in set(input_ids[i].tolist()):
|
score = scores[ranges[:, None], input_ids]
|
||||||
# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
|
|
||||||
if scores[i, previous_token] < 0:
|
# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
|
||||||
scores[i, previous_token] *= self.penalty
|
score = torch.where(score < 0, score * self.penalty, score / self.penalty)
|
||||||
else:
|
|
||||||
scores[i, previous_token] /= self.penalty
|
scores[ranges[:, None], input_ids] = score
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user