Fix: Raise informative exception when prefix_allowed_tokens_fn return empty set of tokens (#27797)

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Saibo-creator
2023-12-08 18:25:49 +08:00
committed by GitHub
parent 307a7d0be8
commit 56be5e80e6
2 changed files with 15 additions and 1 deletions

View File

@@ -610,6 +610,13 @@ class LogitsProcessorTest(unittest.TestCase):
torch.isinf(filtered_scores).tolist(), [[False, False, True, True, True], [True, True, False, False, True]]
)
def empty_prefix_allowed_tokens_fn(batch_id, inputs_ids):
return []
prefix_constrained_logits_proc = PrefixConstrainedLogitsProcessor(empty_prefix_allowed_tokens_fn, 1)
self.assertRaises(ValueError, prefix_constrained_logits_proc, input_ids, scores.clone())
def test_hamming_diversity(self):
vocab_size = 4
num_beams = 2