From 578e18e0028c26d8988413ef270da1b292e87dc2 Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan Date: Thu, 15 Sep 2022 19:20:11 +0530 Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=A8=F0=9F=9A=A8=F0=9F=9A=A8=20Optimize?= =?UTF-8?q?=20Top=20P=20Sampler=20and=20fix=20edge=20case=20(#18984)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * init PR * optimize top p and add edge case * styling * style * revert tf and flax test * add edge case test for FLAX and TF * update doc with smallest set sampling for top p * make style --- .../generation_flax_logits_process.py | 4 ++-- src/transformers/generation_logits_process.py | 15 ++++++--------- src/transformers/generation_tf_logits_process.py | 4 ++-- src/transformers/generation_utils.py | 4 ++-- .../test_generation_flax_logits_process.py | 4 ++-- .../generation/test_generation_logits_process.py | 4 ++-- .../test_generation_tf_logits_process.py | 7 +++++-- 7 files changed, 21 insertions(+), 21 deletions(-) diff --git a/src/transformers/generation_flax_logits_process.py b/src/transformers/generation_flax_logits_process.py index b41da1b9b2..6e8b4b6343 100644 --- a/src/transformers/generation_flax_logits_process.py +++ b/src/transformers/generation_flax_logits_process.py @@ -118,8 +118,8 @@ class FlaxTopPLogitsWarper(FlaxLogitsWarper): Args: top_p (`float`): - If set to < 1, only the most probable tokens with probabilities that add up to `top_p` or higher are kept - for generation. + If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or + higher are kept for generation. filter_value (`float`, *optional*, defaults to `-float("Inf")`): All filtered values will be set to this float value. min_tokens_to_keep (`int`, *optional*, defaults to 1): diff --git a/src/transformers/generation_logits_process.py b/src/transformers/generation_logits_process.py index 638815dced..35ca6c5731 100644 --- a/src/transformers/generation_logits_process.py +++ b/src/transformers/generation_logits_process.py @@ -173,8 +173,8 @@ class TopPLogitsWarper(LogitsWarper): Args: top_p (`float`): - If set to < 1, only the most probable tokens with probabilities that add up to `top_p` or higher are kept - for generation. + If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or + higher are kept for generation. filter_value (`float`, *optional*, defaults to `-float("Inf")`): All filtered values will be set to this float value. min_tokens_to_keep (`int`, *optional*, defaults to 1): @@ -191,17 +191,14 @@ class TopPLogitsWarper(LogitsWarper): self.min_tokens_to_keep = min_tokens_to_keep def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: - sorted_logits, sorted_indices = torch.sort(scores, descending=True) + sorted_logits, sorted_indices = torch.sort(scores, descending=False) cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) # Remove tokens with cumulative top_p above the threshold (token with 0 are kept) - sorted_indices_to_remove = cumulative_probs > self.top_p + sorted_indices_to_remove = cumulative_probs <= (1 - self.top_p) if self.min_tokens_to_keep > 1: - # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) - sorted_indices_to_remove[..., : self.min_tokens_to_keep - 1] = 0 - # Shift the indices to the right to keep also the first token above the threshold - sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() - sorted_indices_to_remove[..., 0] = 0 + # Keep at least min_tokens_to_keep + sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0 # scatter sorted tensors to original indexing indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) diff --git a/src/transformers/generation_tf_logits_process.py b/src/transformers/generation_tf_logits_process.py index f17ed04686..b09330e10b 100644 --- a/src/transformers/generation_tf_logits_process.py +++ b/src/transformers/generation_tf_logits_process.py @@ -150,8 +150,8 @@ class TFTopPLogitsWarper(TFLogitsWarper): Args: top_p (`float`): - If set to < 1, only the most probable tokens with probabilities that add up to `top_p` or higher are kept - for generation. + If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or + higher are kept for generation. filter_value (`float`, *optional*, defaults to `-float("Inf")`): All filtered values will be set to this float value. min_tokens_to_keep (`int`, *optional*, defaults to 1): diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index b0e7b14116..10f15304f4 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -990,8 +990,8 @@ class GenerationMixin: top_k (`int`, *optional*, defaults to `model.config.top_k` or 50 if the config does not set any value): The number of highest probability vocabulary tokens to keep for top-k-filtering. top_p (`float`, *optional*, defaults to `model.config.top_p` or 1.0 if the config does not set any value): - If set to float < 1, only the most probable tokens with probabilities that add up to `top_p` or higher - are kept for generation. + If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to + `top_p` or higher are kept for generation. typical_p (`float`, *optional*, defaults to `model.config.typical_p` or 1.0 if the config does not set any value): The amount of probability mass from the original distribution to be considered in typical decoding. If set to 1.0 it takes no effect. See [this paper](https://arxiv.org/pdf/2202.00666.pdf) for more details. diff --git a/tests/generation/test_generation_flax_logits_process.py b/tests/generation/test_generation_flax_logits_process.py index aea44252d9..3a04f7e38f 100644 --- a/tests/generation/test_generation_flax_logits_process.py +++ b/tests/generation/test_generation_flax_logits_process.py @@ -110,10 +110,10 @@ class LogitsProcessorTest(unittest.TestCase): # create distribution and take log (inverse to Softmax as taken in TopPLogitsWarper) dist = np.log(np.array([[0.3, 0.1, 0.1, 0.5], [0.15, 0.3, 0.3, 0.25]])) - top_p_warp = FlaxTopPLogitsWarper(0.7) + top_p_warp = FlaxTopPLogitsWarper(0.8) filtered_dist = np.exp(top_p_warp(input_ids, dist, cur_len=None)) - # dist should be filtered to keep min num values so that sum is >= 0.7 + # dist should be filtered to keep min num values so that sum is >= top_p # exp (-inf) => 0 EXPECTED_FILTERED_DIST = np.array([[0.3, 0.0, 0.0, 0.5], [0.0, 0.3, 0.3, 0.25]]) self.assertTrue(np.allclose(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3)) diff --git a/tests/generation/test_generation_logits_process.py b/tests/generation/test_generation_logits_process.py index 7a515d3e92..396fccd009 100644 --- a/tests/generation/test_generation_logits_process.py +++ b/tests/generation/test_generation_logits_process.py @@ -169,10 +169,10 @@ class LogitsProcessorTest(unittest.TestCase): torch.tensor([[0.3, 0.1, 0.1, 0.5], [0.15, 0.3, 0.3, 0.25]], device=torch_device, dtype=torch.float) ) - top_p_warp = TopPLogitsWarper(0.7) + top_p_warp = TopPLogitsWarper(0.8) filtered_dist = torch.exp(top_p_warp(input_ids, dist)) - # dist should be filtered to keep min num values so that sum is >= 0.7 + # dist should be filtered to keep min num values so that sum is >= top_p # exp (-inf) => 0 EXPECTED_FILTERED_DIST = torch.tensor( [[0.3, 0.0, 0.0, 0.5], [0.0, 0.3, 0.3, 0.25]], device=torch_device, dtype=torch.float diff --git a/tests/generation/test_generation_tf_logits_process.py b/tests/generation/test_generation_tf_logits_process.py index be60335ef2..676a392204 100644 --- a/tests/generation/test_generation_tf_logits_process.py +++ b/tests/generation/test_generation_tf_logits_process.py @@ -189,12 +189,15 @@ class TFLogitsProcessorTest(unittest.TestCase): # create distribution and take log (inverse to Softmax as taken in TFTopPLogitsWarper) dist = np.log(np.array([[0.3, 0.1, 0.1, 0.5], [0.15, 0.3, 0.3, 0.25]], dtype=np.float32)) - top_p_warp = TFTopPLogitsWarper(0.7) + # top_p should have been 0.8 to test the edge case of top_p being exactly equal to sum of some token prob + # However, due to the numerical instability of softmax in TF we choose this as the edge case + # top_p as 0.8 passes when use_xla is True and fails when False. Refer PR #18984. + top_p_warp = TFTopPLogitsWarper(0.79999995) if use_xla: top_p_warp = tf.function(top_p_warp, jit_compile=True) filtered_dist = tf.exp(top_p_warp(input_ids, dist, cur_len)) - # dist should be filtered to keep min num values so that sum is >= 0.7 + # dist should be filtered to keep min num values so that sum is >= top_p # exp (-inf) => 0 EXPECTED_FILTERED_DIST = tf.constant([[0.3, 0.0, 0.0, 0.5], [0.0, 0.3, 0.3, 0.25]], dtype=tf.float32) tf.debugging.assert_near(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3)