Fix decoding score comparison when using logits processors or warpers (#10638)
* Normalize using a logits warper * Add a flag in `generate` to support the logit renormalization * Add in RAG
This commit is contained in:
@@ -33,6 +33,7 @@ if is_torch_available():
|
||||
ForcedEOSTokenLogitsProcessor,
|
||||
HammingDiversityLogitsProcessor,
|
||||
InfNanRemoveLogitsProcessor,
|
||||
LogitNormalization,
|
||||
LogitsProcessorList,
|
||||
MinLengthLogitsProcessor,
|
||||
NoBadWordsLogitsProcessor,
|
||||
@@ -537,3 +538,18 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
scores_after_start[penalty_start + 1 :, eos_token_id], scores[penalty_start + 1 :, eos_token_id]
|
||||
).all()
|
||||
)
|
||||
|
||||
def test_normalization(self):
|
||||
input_ids = None
|
||||
|
||||
scores = torch.tensor(
|
||||
[[-23.18, -29.96, -43.54, 47.77], [-33.58, -26.87, -32.96, 22.51]], device=torch_device, dtype=torch.float
|
||||
)
|
||||
|
||||
logit_normalization = LogitNormalization()
|
||||
normalized_scores = logit_normalization(input_ids, scores).exp()
|
||||
|
||||
ones = torch.ones(scores.shape[0], device=torch_device, dtype=torch.float)
|
||||
self.assertTrue(normalized_scores.sum(dim=-1).allclose(ones))
|
||||
|
||||
self.assertTrue(normalized_scores.allclose(scores.softmax(dim=-1)))
|
||||
|
||||
Reference in New Issue
Block a user