Add Epsilon- and Eta-Sampling (#21121)
* Add epsilon- and eta-sampling.

  Add epsilon- and eta-sampling, following the official code from https://github.com/john-hewitt/truncation-sampling and adapting it to be more configurable, as required by Hugging Face transformers.
* Add unit tests for epsilon- and eta-sampling.
* Black: fix code formatting.
* Fix docstring spacing.
* Clean up newlines.
* Fix implementation bugs and their associated tests.
* Remove epsilon- and eta-sampling parameters from PretrainedConfig.
* Clarify and clean up the documentation.
* Remove parameters for PretrainedConfig test.
This commit is contained in:
@@ -28,6 +28,8 @@ if is_torch_available():
|
||||
|
||||
from transformers.generation import (
|
||||
EncoderNoRepeatNGramLogitsProcessor,
|
||||
EpsilonLogitsWarper,
|
||||
EtaLogitsWarper,
|
||||
ExponentialDecayLengthPenalty,
|
||||
ForcedBOSTokenLogitsProcessor,
|
||||
ForcedEOSTokenLogitsProcessor,
|
||||
@@ -288,6 +290,80 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
# first batch should keep two tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2.
|
||||
self.assertListEqual((filtered_dist != 0.0).to(torch.long).sum(dim=-1).tolist(), [2, 2])
|
||||
|
||||
def test_epsilon_dist_warper(self):
    """Check `EpsilonLogitsWarper` on a hand-built distribution and on extreme ramp logits."""
    input_ids = None
    vocab_size = 10
    batch_size = 2

    # Start from explicit probabilities and move to log space, so the expected
    # output can be written down directly as probabilities after `exp`.
    probabilities = torch.tensor(
        [[0.87, 0.099, 0.001, 0.03], [0.4, 0.299, 0.101, 0.2]], device=torch_device, dtype=torch.float
    )
    dist = torch.log(probabilities)

    epsilon_warp = EpsilonLogitsWarper(0.1)
    warped_scores = epsilon_warp(input_ids, dist)
    filtered_dist = torch.exp(warped_scores)

    # Only entries with probability >= 0.1 are expected to survive; removed
    # entries show up as 0 after exponentiation (exp(-inf) == 0).
    expected_probabilities = torch.tensor(
        [[0.87, 0, 0, 0], [0.4, 0.299, 0.101, 0.2]], device=torch_device, dtype=torch.float
    )
    is_close = torch.allclose(filtered_dist, expected_probabilities, atol=1e-3)
    self.assertTrue(is_close)

    # Edge cases: logits that are negative and, for the second row, very large in magnitude.
    ramp = torch.arange(vocab_size, device=torch_device, dtype=torch.float)
    ramp_logits = ramp.unsqueeze(0).repeat(batch_size, 1) - (vocab_size // 2)

    # Scale the second row so that a single token dominates the distribution.
    ramp_logits[1] = ramp_logits[1] * 100.0

    # Ask the warper to always retain at least two tokens; filtered entries are written as 0.0.
    epsilon_warp = EpsilonLogitsWarper(5e-2, min_tokens_to_keep=2, filter_value=0.0)
    filtered_dist = epsilon_warp(input_ids, ramp_logits)

    # Row 0 is expected to keep 3 tokens. Row 1 would keep a single token on its
    # own, but `min_tokens_to_keep=2` raises that to 2.
    kept_per_row = (filtered_dist != 0.0).to(torch.long).sum(dim=-1)
    self.assertListEqual(kept_per_row.tolist(), [3, 2])
|
||||
|
||||
def test_eta_dist_warper(self):
    """Check `EtaLogitsWarper` on a hand-built distribution and on extreme ramp logits."""
    input_ids = None
    vocab_size = 10
    batch_size = 2

    # Start from explicit probabilities and move to log space, so the expected
    # output can be written down directly as probabilities after `exp`.
    probabilities = torch.tensor(
        [[0.0, 0.1, 0.8, 0.1], [0.01, 0.04, 0.9, 0.05]], device=torch_device, dtype=torch.float
    )
    dist = torch.log(probabilities)

    eta_warp = EtaLogitsWarper(0.0625)
    warped_scores = eta_warp(input_ids, dist)
    filtered_dist = torch.exp(warped_scores)

    # Expected cutoff per row: min(0.0625, sqrt(0.0625) * e^-H(p)), where H is the
    # entropy function and p is the row's probability vector.
    # Row 0: min(0.0625, 0.1320); row 1: min(0.0625, 0.1644).
    # Removed entries show up as 0 after exponentiation (exp(-inf) == 0).
    expected_probabilities = torch.tensor(
        [[0.0, 0.1, 0.8, 0.1], [0.0, 0.0, 0.9, 0.0]], device=torch_device, dtype=torch.float
    )
    is_close = torch.allclose(filtered_dist, expected_probabilities, atol=1e-3)
    self.assertTrue(is_close)

    # Edge cases: logits that are negative and, for the second row, very large in magnitude.
    ramp = torch.arange(vocab_size, device=torch_device, dtype=torch.float)
    ramp_logits = ramp.unsqueeze(0).repeat(batch_size, 1) - (vocab_size // 2)

    # Scale the second row so that a single token dominates the distribution.
    ramp_logits[1] = ramp_logits[1] * 100.0

    # Ask the warper to always retain at least two tokens; filtered entries are written as 0.0.
    eta_warp = EtaLogitsWarper(0.1, min_tokens_to_keep=2, filter_value=0.0)
    filtered_dist = eta_warp(input_ids, ramp_logits)

    # Row 0 is expected to keep 2 tokens. Row 1 would keep a single token on its
    # own, but `min_tokens_to_keep=2` raises that to 2.
    kept_per_row = (filtered_dist != 0.0).to(torch.long).sum(dim=-1)
    self.assertListEqual(kept_per_row.tolist(), [2, 2])
|
||||
|
||||
def test_no_repeat_ngram_dist_processor(self):
|
||||
vocab_size = 3
|
||||
batch_size = 2
|
||||
|
||||
Reference in New Issue
Block a user