From 3b00b623b7cad9e1b7c71c97fff24a0286b37045 Mon Sep 17 00:00:00 2001 From: unifyh <18213435+unifyh@users.noreply.github.com> Date: Wed, 22 Jun 2022 03:35:55 +0800 Subject: [PATCH] Fix `top_k_top_p_filtering` having unexpected behavior (#17744) - Fix `top_k_top_p_filtering` not passing `filter_value` to `TopPLogitsWarper` causing any top-p filtered logits to be -inf instead of specified value - Add corresponding test --- src/transformers/generation_utils.py | 4 +++- tests/generation/test_generation_utils.py | 26 +++++++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index be34b742fc..b93868409c 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -3347,6 +3347,8 @@ def top_k_top_p_filtering( ) if 0 <= top_p <= 1.0: - logits = TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=min_tokens_to_keep)(None, logits) + logits = TopPLogitsWarper(top_p=top_p, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)( + None, logits + ) return logits diff --git a/tests/generation/test_generation_utils.py b/tests/generation/test_generation_utils.py index c52a72450e..3b1aa7643f 100644 --- a/tests/generation/test_generation_utils.py +++ b/tests/generation/test_generation_utils.py @@ -1626,6 +1626,32 @@ class UtilsFunctionsTest(unittest.TestCase): self.assertTrue(torch.allclose(non_inf_expected_output, non_inf_output, atol=1e-12)) self.assertTrue(torch.all(torch.eq(non_inf_expected_idx, non_inf_idx))) + # tests whether the function uses filter_value instead of default -inf + def test_top_k_top_p_filtering_with_filter_value(self): + logits = torch.tensor( + [ + [ + 1, + 1, + 1, + 0.99, # get filtered by top-p filtering + 0.98, # get filtered by top-k filtering + ] + ], + dtype=torch.float, + device=torch_device, + ) + + expected_output = torch.tensor( + [[1, 1, 1, 0, 0]], + dtype=torch.float, + device=torch_device, + ) + + output = top_k_top_p_filtering(logits, top_k=4, top_p=0.5, filter_value=0.0) + + self.assertTrue(torch.allclose(expected_output, output, atol=1e-12)) + @require_torch class GenerationIntegrationTests(unittest.TestCase):