Add implementation of typical sampling (#15504)

* typical decoding

* changing arg name

* add test config params

* forgotten arg rename

* fix edge case where scores are same

* test for typical logits warper

* code quality fixes
This commit is contained in:
Clara Meister
2022-02-09 16:48:41 +01:00
committed by GitHub
parent f588cf4050
commit 0113aae5b7
5 changed files with 94 additions and 3 deletions

View File

@@ -41,6 +41,7 @@ if is_torch_available():
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
TypicalLogitsWarper,
)
@@ -191,6 +192,51 @@ class LogitsProcessorTest(unittest.TestCase):
# first batch should keep three tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2.
self.assertListEqual((filtered_dist != 0.0).to(torch.long).sum(dim=-1).tolist(), [3, 2])
def test_typical_dist_warper(self):
input_ids = None
vocab_size = 10
batch_size = 2
# create distribution and take log (inverse to Softmax as taken in TopPLogitsWarper)
dist = torch.log(
torch.tensor([[0.97, 0.01, 0.01, 0.01], [0.4, 0.2, 0.2, 0.2]], device=torch_device, dtype=torch.float)
)
typical_warp = TypicalLogitsWarper(0.5)
filtered_dist = torch.exp(typical_warp(input_ids, dist))
# dist should be filtered to keep min num values so that sum is >= 0.7
# exp (-inf) => 0
EXPECTED_FILTERED_DIST = torch.tensor(
[[0.97, 0.0, 0.0, 0.0], [0.0, 0.2, 0.2, 0.2]], device=torch_device, dtype=torch.float
)
self.assertTrue(torch.allclose(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3))
# check special cases
length = 5
logits = self._get_uniform_logits(batch_size=batch_size, length=length)
typical_warp_safety_check = TypicalLogitsWarper(mass=0.5, filter_value=0.0, min_tokens_to_keep=3)
scores = typical_warp_safety_check(input_ids, logits)
# uniform dist is not changed
self.assertListEqual((scores == 0.0).to(torch.long).sum(dim=-1).tolist(), [0, 0])
# check edge cases with negative and extreme logits
ramp_logits = torch.arange(vocab_size, device=torch_device, dtype=torch.float).unsqueeze(0).repeat(
batch_size, 1
) - (vocab_size // 2)
# make ramp_logits more extreme
ramp_logits[1] = ramp_logits[1] * 100.0
# make sure at least 2 tokens are kept
typical_warp = TypicalLogitsWarper(0.7, min_tokens_to_keep=2, filter_value=0.0)
filtered_dist = typical_warp(input_ids, ramp_logits)
# first batch should keep two tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2.
self.assertListEqual((filtered_dist != 0.0).to(torch.long).sum(dim=-1).tolist(), [2, 2])
def test_no_repeat_ngram_dist_processor(self):
vocab_size = 3
batch_size = 2