🚨🚨🚨 Optimize Top P Sampler and fix edge case (#18984)

* init PR

* optimize top p and add edge case

* styling

* style

* revert tf and flax test

* add edge case test for FLAX and TF

* update doc with smallest set sampling for top p

* make style
This commit is contained in:
Ekagra Ranjan
2022-09-15 19:20:11 +05:30
committed by GitHub
parent 2700ba66d9
commit 578e18e002
7 changed files with 21 additions and 21 deletions

View File

@@ -110,10 +110,10 @@ class LogitsProcessorTest(unittest.TestCase):
# create distribution and take log (inverse to Softmax as taken in TopPLogitsWarper)
dist = np.log(np.array([[0.3, 0.1, 0.1, 0.5], [0.15, 0.3, 0.3, 0.25]]))
top_p_warp = FlaxTopPLogitsWarper(0.7)
top_p_warp = FlaxTopPLogitsWarper(0.8)
filtered_dist = np.exp(top_p_warp(input_ids, dist, cur_len=None))
# dist should be filtered to keep min num values so that sum is >= 0.7
# dist should be filtered to keep min num values so that sum is >= top_p
# exp (-inf) => 0
EXPECTED_FILTERED_DIST = np.array([[0.3, 0.0, 0.0, 0.5], [0.0, 0.3, 0.3, 0.25]])
self.assertTrue(np.allclose(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3))

View File

@@ -169,10 +169,10 @@ class LogitsProcessorTest(unittest.TestCase):
torch.tensor([[0.3, 0.1, 0.1, 0.5], [0.15, 0.3, 0.3, 0.25]], device=torch_device, dtype=torch.float)
)
top_p_warp = TopPLogitsWarper(0.7)
top_p_warp = TopPLogitsWarper(0.8)
filtered_dist = torch.exp(top_p_warp(input_ids, dist))
# dist should be filtered to keep min num values so that sum is >= 0.7
# dist should be filtered to keep min num values so that sum is >= top_p
# exp (-inf) => 0
EXPECTED_FILTERED_DIST = torch.tensor(
[[0.3, 0.0, 0.0, 0.5], [0.0, 0.3, 0.3, 0.25]], device=torch_device, dtype=torch.float

View File

@@ -189,12 +189,15 @@ class TFLogitsProcessorTest(unittest.TestCase):
# create distribution and take log (inverse to Softmax as taken in TFTopPLogitsWarper)
dist = np.log(np.array([[0.3, 0.1, 0.1, 0.5], [0.15, 0.3, 0.3, 0.25]], dtype=np.float32))
top_p_warp = TFTopPLogitsWarper(0.7)
# top_p should have been 0.8 to test the edge case of top_p being exactly equal to sum of some token prob
# However, due to the numerical instability of softmax in TF we choose this as the edge case
# top_p as 0.8 passes when use_xla is True and fails when False. Refer PR #18984.
top_p_warp = TFTopPLogitsWarper(0.79999995)
if use_xla:
top_p_warp = tf.function(top_p_warp, jit_compile=True)
filtered_dist = tf.exp(top_p_warp(input_ids, dist, cur_len))
# dist should be filtered to keep min num values so that sum is >= 0.7
# dist should be filtered to keep min num values so that sum is >= top_p
# exp (-inf) => 0
EXPECTED_FILTERED_DIST = tf.constant([[0.3, 0.0, 0.0, 0.5], [0.0, 0.3, 0.3, 0.25]], dtype=tf.float32)
tf.debugging.assert_near(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3)