🚨🚨🚨 Optimize Top P Sampler and fix edge case (#18984)
* init PR * optimize top p and add edge case * styling * style * revert tf and flax test * add edge case test for FLAX and TF * update doc with smallest set sampling for top p * make style
This commit is contained in:
@@ -110,10 +110,10 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
# create distribution and take log (inverse to Softmax as taken in TopPLogitsWarper)
|
||||
dist = np.log(np.array([[0.3, 0.1, 0.1, 0.5], [0.15, 0.3, 0.3, 0.25]]))
|
||||
|
||||
top_p_warp = FlaxTopPLogitsWarper(0.7)
|
||||
top_p_warp = FlaxTopPLogitsWarper(0.8)
|
||||
filtered_dist = np.exp(top_p_warp(input_ids, dist, cur_len=None))
|
||||
|
||||
# dist should be filtered to keep min num values so that sum is >= 0.7
|
||||
# dist should be filtered to keep min num values so that sum is >= top_p
|
||||
# exp (-inf) => 0
|
||||
EXPECTED_FILTERED_DIST = np.array([[0.3, 0.0, 0.0, 0.5], [0.0, 0.3, 0.3, 0.25]])
|
||||
self.assertTrue(np.allclose(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3))
|
||||
|
||||
Reference in New Issue
Block a user