Fix top_k_top_p_filtering having unexpected behavior (#17744)
- Fix `top_k_top_p_filtering` not passing `filter_value` to `TopPLogitsWarper`, which caused any logits filtered by top-p to be set to -inf instead of the specified `filter_value`
- Add a corresponding test
This commit is contained in:
@@ -3347,6 +3347,8 @@ def top_k_top_p_filtering(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if 0 <= top_p <= 1.0:
|
if 0 <= top_p <= 1.0:
|
||||||
logits = TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=min_tokens_to_keep)(None, logits)
|
logits = TopPLogitsWarper(top_p=top_p, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
|
||||||
|
None, logits
|
||||||
|
)
|
||||||
|
|
||||||
return logits
|
return logits
|
||||||
|
|||||||
@@ -1626,6 +1626,32 @@ class UtilsFunctionsTest(unittest.TestCase):
|
|||||||
self.assertTrue(torch.allclose(non_inf_expected_output, non_inf_output, atol=1e-12))
|
self.assertTrue(torch.allclose(non_inf_expected_output, non_inf_output, atol=1e-12))
|
||||||
self.assertTrue(torch.all(torch.eq(non_inf_expected_idx, non_inf_idx)))
|
self.assertTrue(torch.all(torch.eq(non_inf_expected_idx, non_inf_idx)))
|
||||||
|
|
||||||
|
# tests whether the function uses filter_value instead of default -inf
|
||||||
|
def test_top_k_top_p_filtering_with_filter_value(self):
|
||||||
|
logits = torch.tensor(
|
||||||
|
[
|
||||||
|
[
|
||||||
|
1,
|
||||||
|
1,
|
||||||
|
1,
|
||||||
|
0.99, # get filtered by top-p filtering
|
||||||
|
0.98, # get filtered by top-k filtering
|
||||||
|
]
|
||||||
|
],
|
||||||
|
dtype=torch.float,
|
||||||
|
device=torch_device,
|
||||||
|
)
|
||||||
|
|
||||||
|
expected_output = torch.tensor(
|
||||||
|
[[1, 1, 1, 0, 0]],
|
||||||
|
dtype=torch.float,
|
||||||
|
device=torch_device,
|
||||||
|
)
|
||||||
|
|
||||||
|
output = top_k_top_p_filtering(logits, top_k=4, top_p=0.5, filter_value=0.0)
|
||||||
|
|
||||||
|
self.assertTrue(torch.allclose(expected_output, output, atol=1e-12))
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class GenerationIntegrationTests(unittest.TestCase):
|
class GenerationIntegrationTests(unittest.TestCase):
|
||||||
|
|||||||
Reference in New Issue
Block a user