🚨🚨🚨 Optimize Top P Sampler and fix edge case (#18984)

* init PR

* optimize top p and add edge case

* styling

* style

* revert tf and flax test

* add edge case test for FLAX and TF

* update doc with smallest set sampling for top p

* make style
This commit is contained in:
Ekagra Ranjan
2022-09-15 19:20:11 +05:30
committed by GitHub
parent 2700ba66d9
commit 578e18e002
7 changed files with 21 additions and 21 deletions

View File

@@ -110,10 +110,10 @@ class LogitsProcessorTest(unittest.TestCase):
# create distribution and take log (inverse to Softmax as taken in TopPLogitsWarper)
dist = np.log(np.array([[0.3, 0.1, 0.1, 0.5], [0.15, 0.3, 0.3, 0.25]]))
top_p_warp = FlaxTopPLogitsWarper(0.7)
top_p_warp = FlaxTopPLogitsWarper(0.8)
filtered_dist = np.exp(top_p_warp(input_ids, dist, cur_len=None))
# dist should be filtered to keep min num values so that sum is >= 0.7
# dist should be filtered to keep min num values so that sum is >= top_p
# exp (-inf) => 0
EXPECTED_FILTERED_DIST = np.array([[0.3, 0.0, 0.0, 0.5], [0.0, 0.3, 0.3, 0.25]])
self.assertTrue(np.allclose(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3))