Fix decoding score comparison when using logits processors or warpers (#10638)

* Normalize using a logits warper

* Add a flag in `generate` to support the logit renormalization

* Add in RAG
This commit is contained in:
Santiago Castro
2022-04-13 04:37:33 -04:00
committed by GitHub
parent eb5bdcdfa5
commit f7196f2e63
4 changed files with 58 additions and 2 deletions

View File

@@ -33,6 +33,7 @@ if is_torch_available():
ForcedEOSTokenLogitsProcessor,
HammingDiversityLogitsProcessor,
InfNanRemoveLogitsProcessor,
LogitNormalization,
LogitsProcessorList,
MinLengthLogitsProcessor,
NoBadWordsLogitsProcessor,
@@ -537,3 +538,18 @@ class LogitsProcessorTest(unittest.TestCase):
scores_after_start[penalty_start + 1 :, eos_token_id], scores[penalty_start + 1 :, eos_token_id]
).all()
)
def test_normalization(self):
input_ids = None
scores = torch.tensor(
[[-23.18, -29.96, -43.54, 47.77], [-33.58, -26.87, -32.96, 22.51]], device=torch_device, dtype=torch.float
)
logit_normalization = LogitNormalization()
normalized_scores = logit_normalization(input_ids, scores).exp()
ones = torch.ones(scores.shape[0], device=torch_device, dtype=torch.float)
self.assertTrue(normalized_scores.sum(dim=-1).allclose(ones))
self.assertTrue(normalized_scores.allclose(scores.softmax(dim=-1)))