Generate: add min_p sampling (#30639)
* min_p
* more relaxed test to avoid numerical issues
* Update src/transformers/generation/logits_process.py
  Co-authored-by: menhguin <minh1228@gmail.com>
* Update src/transformers/generation/configuration_utils.py
  Co-authored-by: menhguin <minh1228@gmail.com>
* docstring clarifications
* PR comments
* Update tests/generation/test_logits_process.py
  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* make fixup
---------
Co-authored-by: menhguin <minh1228@gmail.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -42,6 +42,7 @@ if is_torch_available():
|
||||
LogitsProcessorList,
|
||||
MinLengthLogitsProcessor,
|
||||
MinNewTokensLengthLogitsProcessor,
|
||||
MinPLogitsWarper,
|
||||
NoBadWordsLogitsProcessor,
|
||||
NoRepeatNGramLogitsProcessor,
|
||||
PrefixConstrainedLogitsProcessor,
|
||||
@@ -304,6 +305,52 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
# first batch should keep three tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2.
|
||||
self.assertListEqual((filtered_dist != 0.0).to(torch.long).sum(dim=-1).tolist(), [3, 2])
|
||||
|
||||
def test_min_p_dist_warper(self):
# Checks MinPLogitsWarper on (1) hand-written probability rows with known survivors and
# (2) ramp-shaped logits where `min_tokens_to_keep` must override the min_p cut-off.
|
||||
input_ids = None
|
||||
vocab_size = 10
|
||||
batch_size = 2
|
||||
|
||||
# create distribution and take log (inverse to Softmax as taken in MinPLogitsWarper)
# NOTE: `vocab_size`/`batch_size` above are only used for the ramp logits further down;
# this hand-written distribution has 3 rows of 4 tokens each.
|
||||
dist = torch.log(
|
||||
torch.tensor(
|
||||
[
|
||||
[0.9, 0.0274, 0.047, 0.0274], # threshold 0.9*0.05=0.045: keeps 0.9 and 0.047, drops both 0.0274 entries
|
||||
[0.15, 0.3, 0.3, 0.25], # threshold 0.3*0.05=0.015: all kept -- no high-probability token
|
||||
[0.97, 0.01, 0.01, 0.01], # threshold 0.97*0.05=0.0485: only the first token is kept
|
||||
],
|
||||
device=torch_device,
|
||||
dtype=torch.float,
|
||||
)
|
||||
)
|
||||
|
||||
min_p_warp = MinPLogitsWarper(0.05)
|
||||
filtered_dist = torch.exp(min_p_warp(input_ids, dist))
|
||||
|
||||
# filtered logits are expected to be -inf, and exp(-inf) => 0, so removed tokens show up as 0.0 below
|
||||
EXPECTED_FILTERED_DIST = torch.tensor(
|
||||
[[0.9, 0.0, 0.047, 0.0], [0.15, 0.3, 0.3, 0.25], [0.97, 0.0, 0.0, 0.0]],
|
||||
device=torch_device,
|
||||
dtype=torch.float,
|
||||
)
|
||||
self.assertTrue(torch.allclose(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3))
|
||||
|
||||
# processor should not change logits in-place: if `dist` itself had been overwritten,
# the returned scores would compare equal to `dist` everywhere
|
||||
self.assertFalse(torch.all(min_p_warp(input_ids, dist) == dist))
|
||||
|
||||
# check edge cases with negative and extreme logits
# (each row is a ramp of consecutive integers centred around zero: -5 .. 4 for vocab_size=10)
|
||||
ramp_logits = torch.arange(vocab_size, device=torch_device, dtype=torch.float) - (vocab_size // 2)
|
||||
ramp_logits = ramp_logits.unsqueeze(0).repeat(batch_size, 1)
|
||||
|
||||
# make ramp_logits more extreme (second row only, scaled by 100)
|
||||
ramp_logits[1] = ramp_logits[1] * 100.0
|
||||
|
||||
# make sure at least 2 tokens are kept; `filter_value=0.0` lets removed entries be counted with `!= 0.0`
|
||||
min_p_warp = MinPLogitsWarper(0.9, min_tokens_to_keep=2, filter_value=0.0)
|
||||
filtered_dist = min_p_warp(input_ids, ramp_logits)
|
||||
|
||||
# with min_p=0.9 the runner-up token is below the cut-off in both rows (adjacent logits differ by >= 1,
# i.e. its probability is at most e^-1 ~= 0.37 of the top one), so min_p alone would keep only 1 token
# per row; `min_tokens_to_keep=2` is what brings both rows to 2 kept tokens.
|
||||
self.assertListEqual((filtered_dist != 0.0).to(torch.long).sum(dim=-1).tolist(), [2, 2])
|
||||
|
||||
def test_typical_dist_warper(self):
|
||||
input_ids = None
|
||||
vocab_size = 10
|
||||
|
||||
Reference in New Issue
Block a user