🚨🚨🚨 Optimize Top P Sampler and fix edge case (#18984)
* init PR * optimize top p and add edge case * styling * style * revert tf and flax test * add edge case test for FLAX and TF * update doc with smallest set sampling for top p * make style
This commit is contained in:
@@ -118,8 +118,8 @@ class FlaxTopPLogitsWarper(FlaxLogitsWarper):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
top_p (`float`):
|
top_p (`float`):
|
||||||
If set to < 1, only the most probable tokens with probabilities that add up to `top_p` or higher are kept
|
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
|
||||||
for generation.
|
higher are kept for generation.
|
||||||
filter_value (`float`, *optional*, defaults to `-float("Inf")`):
|
filter_value (`float`, *optional*, defaults to `-float("Inf")`):
|
||||||
All filtered values will be set to this float value.
|
All filtered values will be set to this float value.
|
||||||
min_tokens_to_keep (`int`, *optional*, defaults to 1):
|
min_tokens_to_keep (`int`, *optional*, defaults to 1):
|
||||||
|
|||||||
@@ -173,8 +173,8 @@ class TopPLogitsWarper(LogitsWarper):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
top_p (`float`):
|
top_p (`float`):
|
||||||
If set to < 1, only the most probable tokens with probabilities that add up to `top_p` or higher are kept
|
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
|
||||||
for generation.
|
higher are kept for generation.
|
||||||
filter_value (`float`, *optional*, defaults to `-float("Inf")`):
|
filter_value (`float`, *optional*, defaults to `-float("Inf")`):
|
||||||
All filtered values will be set to this float value.
|
All filtered values will be set to this float value.
|
||||||
min_tokens_to_keep (`int`, *optional*, defaults to 1):
|
min_tokens_to_keep (`int`, *optional*, defaults to 1):
|
||||||
@@ -191,17 +191,14 @@ class TopPLogitsWarper(LogitsWarper):
|
|||||||
self.min_tokens_to_keep = min_tokens_to_keep
|
self.min_tokens_to_keep = min_tokens_to_keep
|
||||||
|
|
||||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
sorted_logits, sorted_indices = torch.sort(scores, descending=True)
|
sorted_logits, sorted_indices = torch.sort(scores, descending=False)
|
||||||
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
|
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
|
||||||
|
|
||||||
# Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
|
# Remove tokens whose cumulative probability (in ascending order) is at or below 1 - top_p (mask entries equal to 0 are kept)
|
||||||
sorted_indices_to_remove = cumulative_probs > self.top_p
|
sorted_indices_to_remove = cumulative_probs <= (1 - self.top_p)
|
||||||
if self.min_tokens_to_keep > 1:
|
if self.min_tokens_to_keep > 1:
|
||||||
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
|
# Keep at least min_tokens_to_keep
|
||||||
sorted_indices_to_remove[..., : self.min_tokens_to_keep - 1] = 0
|
sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
|
||||||
# Shift the indices to the right to keep also the first token above the threshold
|
|
||||||
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
|
||||||
sorted_indices_to_remove[..., 0] = 0
|
|
||||||
|
|
||||||
# scatter sorted tensors to original indexing
|
# scatter sorted tensors to original indexing
|
||||||
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
||||||
|
|||||||
@@ -150,8 +150,8 @@ class TFTopPLogitsWarper(TFLogitsWarper):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
top_p (`float`):
|
top_p (`float`):
|
||||||
If set to < 1, only the most probable tokens with probabilities that add up to `top_p` or higher are kept
|
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
|
||||||
for generation.
|
higher are kept for generation.
|
||||||
filter_value (`float`, *optional*, defaults to `-float("Inf")`):
|
filter_value (`float`, *optional*, defaults to `-float("Inf")`):
|
||||||
All filtered values will be set to this float value.
|
All filtered values will be set to this float value.
|
||||||
min_tokens_to_keep (`int`, *optional*, defaults to 1):
|
min_tokens_to_keep (`int`, *optional*, defaults to 1):
|
||||||
|
|||||||
@@ -990,8 +990,8 @@ class GenerationMixin:
|
|||||||
top_k (`int`, *optional*, defaults to `model.config.top_k` or 50 if the config does not set any value):
|
top_k (`int`, *optional*, defaults to `model.config.top_k` or 50 if the config does not set any value):
|
||||||
The number of highest probability vocabulary tokens to keep for top-k-filtering.
|
The number of highest probability vocabulary tokens to keep for top-k-filtering.
|
||||||
top_p (`float`, *optional*, defaults to `model.config.top_p` or 1.0 if the config does not set any value):
|
top_p (`float`, *optional*, defaults to `model.config.top_p` or 1.0 if the config does not set any value):
|
||||||
If set to float < 1, only the most probable tokens with probabilities that add up to `top_p` or higher
|
If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to
|
||||||
are kept for generation.
|
`top_p` or higher are kept for generation.
|
||||||
typical_p (`float`, *optional*, defaults to `model.config.typical_p` or 1.0 if the config does not set any value):
|
typical_p (`float`, *optional*, defaults to `model.config.typical_p` or 1.0 if the config does not set any value):
|
||||||
The amount of probability mass from the original distribution to be considered in typical decoding. If
|
The amount of probability mass from the original distribution to be considered in typical decoding. If
|
||||||
set to 1.0 it takes no effect. See [this paper](https://arxiv.org/pdf/2202.00666.pdf) for more details.
|
set to 1.0 it takes no effect. See [this paper](https://arxiv.org/pdf/2202.00666.pdf) for more details.
|
||||||
|
|||||||
@@ -110,10 +110,10 @@ class LogitsProcessorTest(unittest.TestCase):
|
|||||||
# create distribution and take log (inverse to Softmax as taken in TopPLogitsWarper)
|
# create distribution and take log (inverse to Softmax as taken in TopPLogitsWarper)
|
||||||
dist = np.log(np.array([[0.3, 0.1, 0.1, 0.5], [0.15, 0.3, 0.3, 0.25]]))
|
dist = np.log(np.array([[0.3, 0.1, 0.1, 0.5], [0.15, 0.3, 0.3, 0.25]]))
|
||||||
|
|
||||||
top_p_warp = FlaxTopPLogitsWarper(0.7)
|
top_p_warp = FlaxTopPLogitsWarper(0.8)
|
||||||
filtered_dist = np.exp(top_p_warp(input_ids, dist, cur_len=None))
|
filtered_dist = np.exp(top_p_warp(input_ids, dist, cur_len=None))
|
||||||
|
|
||||||
# dist should be filtered to keep min num values so that sum is >= 0.7
|
# dist should be filtered to keep min num values so that sum is >= top_p
|
||||||
# exp (-inf) => 0
|
# exp (-inf) => 0
|
||||||
EXPECTED_FILTERED_DIST = np.array([[0.3, 0.0, 0.0, 0.5], [0.0, 0.3, 0.3, 0.25]])
|
EXPECTED_FILTERED_DIST = np.array([[0.3, 0.0, 0.0, 0.5], [0.0, 0.3, 0.3, 0.25]])
|
||||||
self.assertTrue(np.allclose(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3))
|
self.assertTrue(np.allclose(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3))
|
||||||
|
|||||||
@@ -169,10 +169,10 @@ class LogitsProcessorTest(unittest.TestCase):
|
|||||||
torch.tensor([[0.3, 0.1, 0.1, 0.5], [0.15, 0.3, 0.3, 0.25]], device=torch_device, dtype=torch.float)
|
torch.tensor([[0.3, 0.1, 0.1, 0.5], [0.15, 0.3, 0.3, 0.25]], device=torch_device, dtype=torch.float)
|
||||||
)
|
)
|
||||||
|
|
||||||
top_p_warp = TopPLogitsWarper(0.7)
|
top_p_warp = TopPLogitsWarper(0.8)
|
||||||
filtered_dist = torch.exp(top_p_warp(input_ids, dist))
|
filtered_dist = torch.exp(top_p_warp(input_ids, dist))
|
||||||
|
|
||||||
# dist should be filtered to keep min num values so that sum is >= 0.7
|
# dist should be filtered to keep min num values so that sum is >= top_p
|
||||||
# exp (-inf) => 0
|
# exp (-inf) => 0
|
||||||
EXPECTED_FILTERED_DIST = torch.tensor(
|
EXPECTED_FILTERED_DIST = torch.tensor(
|
||||||
[[0.3, 0.0, 0.0, 0.5], [0.0, 0.3, 0.3, 0.25]], device=torch_device, dtype=torch.float
|
[[0.3, 0.0, 0.0, 0.5], [0.0, 0.3, 0.3, 0.25]], device=torch_device, dtype=torch.float
|
||||||
|
|||||||
@@ -189,12 +189,15 @@ class TFLogitsProcessorTest(unittest.TestCase):
|
|||||||
# create distribution and take log (inverse to Softmax as taken in TFTopPLogitsWarper)
|
# create distribution and take log (inverse to Softmax as taken in TFTopPLogitsWarper)
|
||||||
dist = np.log(np.array([[0.3, 0.1, 0.1, 0.5], [0.15, 0.3, 0.3, 0.25]], dtype=np.float32))
|
dist = np.log(np.array([[0.3, 0.1, 0.1, 0.5], [0.15, 0.3, 0.3, 0.25]], dtype=np.float32))
|
||||||
|
|
||||||
top_p_warp = TFTopPLogitsWarper(0.7)
|
# top_p should have been 0.8 to test the edge case of top_p being exactly equal to the sum of some token probabilities
|
||||||
|
# However, due to the numerical instability of softmax in TF, we use 0.79999995 as the edge case instead
|
||||||
|
# With top_p = 0.8 the test passes when use_xla is True and fails when it is False. See PR #18984.
|
||||||
|
top_p_warp = TFTopPLogitsWarper(0.79999995)
|
||||||
if use_xla:
|
if use_xla:
|
||||||
top_p_warp = tf.function(top_p_warp, jit_compile=True)
|
top_p_warp = tf.function(top_p_warp, jit_compile=True)
|
||||||
filtered_dist = tf.exp(top_p_warp(input_ids, dist, cur_len))
|
filtered_dist = tf.exp(top_p_warp(input_ids, dist, cur_len))
|
||||||
|
|
||||||
# dist should be filtered to keep min num values so that sum is >= 0.7
|
# dist should be filtered to keep min num values so that sum is >= top_p
|
||||||
# exp (-inf) => 0
|
# exp (-inf) => 0
|
||||||
EXPECTED_FILTERED_DIST = tf.constant([[0.3, 0.0, 0.0, 0.5], [0.0, 0.3, 0.3, 0.25]], dtype=tf.float32)
|
EXPECTED_FILTERED_DIST = tf.constant([[0.3, 0.0, 0.0, 0.5], [0.0, 0.3, 0.3, 0.25]], dtype=tf.float32)
|
||||||
tf.debugging.assert_near(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3)
|
tf.debugging.assert_near(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3)
|
||||||
|
|||||||
Reference in New Issue
Block a user