🚨🚨🚨 Optimize Top P Sampler and fix edge case (#18984)
* init PR * optimize top p and add edge case * styling * style * revert tf and flax test * add edge case test for FLAX and TF * update doc with smallest set sampling for top p * make style
This commit is contained in:
@@ -110,10 +110,10 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
# create distribution and take log (inverse to Softmax as taken in TopPLogitsWarper)
|
||||
dist = np.log(np.array([[0.3, 0.1, 0.1, 0.5], [0.15, 0.3, 0.3, 0.25]]))
|
||||
|
||||
top_p_warp = FlaxTopPLogitsWarper(0.7)
|
||||
top_p_warp = FlaxTopPLogitsWarper(0.8)
|
||||
filtered_dist = np.exp(top_p_warp(input_ids, dist, cur_len=None))
|
||||
|
||||
# dist should be filtered to keep min num values so that sum is >= 0.7
|
||||
# dist should be filtered to keep min num values so that sum is >= top_p
|
||||
# exp (-inf) => 0
|
||||
EXPECTED_FILTERED_DIST = np.array([[0.3, 0.0, 0.0, 0.5], [0.0, 0.3, 0.3, 0.25]])
|
||||
self.assertTrue(np.allclose(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3))
|
||||
|
||||
@@ -169,10 +169,10 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
torch.tensor([[0.3, 0.1, 0.1, 0.5], [0.15, 0.3, 0.3, 0.25]], device=torch_device, dtype=torch.float)
|
||||
)
|
||||
|
||||
top_p_warp = TopPLogitsWarper(0.7)
|
||||
top_p_warp = TopPLogitsWarper(0.8)
|
||||
filtered_dist = torch.exp(top_p_warp(input_ids, dist))
|
||||
|
||||
# dist should be filtered to keep min num values so that sum is >= 0.7
|
||||
# dist should be filtered to keep min num values so that sum is >= top_p
|
||||
# exp (-inf) => 0
|
||||
EXPECTED_FILTERED_DIST = torch.tensor(
|
||||
[[0.3, 0.0, 0.0, 0.5], [0.0, 0.3, 0.3, 0.25]], device=torch_device, dtype=torch.float
|
||||
|
||||
@@ -189,12 +189,15 @@ class TFLogitsProcessorTest(unittest.TestCase):
|
||||
# create distribution and take log (inverse to Softmax as taken in TFTopPLogitsWarper)
|
||||
dist = np.log(np.array([[0.3, 0.1, 0.1, 0.5], [0.15, 0.3, 0.3, 0.25]], dtype=np.float32))
|
||||
|
||||
top_p_warp = TFTopPLogitsWarper(0.7)
|
||||
# top_p should have been 0.8 to test the edge case of top_p being exactly equal to sum of some token prob
|
||||
# However, due to the numerical instability of softmax in TF we choose this as the edge case
|
||||
# top_p as 0.8 passes when use_xla is True and fails when False. Refer PR #18984.
|
||||
top_p_warp = TFTopPLogitsWarper(0.79999995)
|
||||
if use_xla:
|
||||
top_p_warp = tf.function(top_p_warp, jit_compile=True)
|
||||
filtered_dist = tf.exp(top_p_warp(input_ids, dist, cur_len))
|
||||
|
||||
# dist should be filtered to keep min num values so that sum is >= 0.7
|
||||
# dist should be filtered to keep min num values so that sum is >= top_p
|
||||
# exp (-inf) => 0
|
||||
EXPECTED_FILTERED_DIST = tf.constant([[0.3, 0.0, 0.0, 0.5], [0.0, 0.3, 0.3, 0.25]], dtype=tf.float32)
|
||||
tf.debugging.assert_near(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3)
|
||||
|
||||
Reference in New Issue
Block a user