Merge pull request #2303 from patrickvonplaten/fix_error_with_repetition_penalty
fix repetition penalty error in modeling_utils.py
This commit is contained in:
@@ -728,7 +728,11 @@ class PreTrainedModel(nn.Module):
|
|||||||
if repetition_penalty != 1.0:
|
if repetition_penalty != 1.0:
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
for previous_tokens in set(input_ids[i].tolist()):
|
for previous_tokens in set(input_ids[i].tolist()):
|
||||||
next_token_logits[i, previous_tokens] /= repetition_penalty
|
# if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
|
||||||
|
if next_token_logits[i, previous_tokens] < 0:
|
||||||
|
next_token_logits[i, previous_tokens] *= repetition_penalty
|
||||||
|
else:
|
||||||
|
next_token_logits[i, previous_tokens] /= repetition_penalty
|
||||||
|
|
||||||
if do_sample:
|
if do_sample:
|
||||||
# Temperature (higher temperature => more likely to sample low probability tokens)
|
# Temperature (higher temperature => more likely to sample low probability tokens)
|
||||||
@@ -807,7 +811,11 @@ class PreTrainedModel(nn.Module):
|
|||||||
if repetition_penalty != 1.0:
|
if repetition_penalty != 1.0:
|
||||||
for i in range(batch_size * num_beams):
|
for i in range(batch_size * num_beams):
|
||||||
for previous_tokens in set(input_ids[i].tolist()):
|
for previous_tokens in set(input_ids[i].tolist()):
|
||||||
scores[i, previous_tokens] /= repetition_penalty
|
# if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
|
||||||
|
if scores[i, previous_tokens] < 0:
|
||||||
|
scores[i, previous_tokens] *= repetition_penalty
|
||||||
|
else:
|
||||||
|
scores[i, previous_tokens] /= repetition_penalty
|
||||||
|
|
||||||
if do_sample:
|
if do_sample:
|
||||||
# Temperature (higher temperature => more likely to sample low probability tokens)
|
# Temperature (higher temperature => more likely to sample low probability tokens)
|
||||||
|
|||||||
Reference in New Issue
Block a user