Fix: Raise informative exception when prefix_allowed_tokens_fn return empty set of tokens (#27797)
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -1229,7 +1229,14 @@ class PrefixConstrainedLogitsProcessor(LogitsProcessor):
|
|||||||
mask = torch.full_like(scores, -math.inf)
|
mask = torch.full_like(scores, -math.inf)
|
||||||
for batch_id, beam_sent in enumerate(input_ids.view(-1, self._num_beams, input_ids.shape[-1])):
|
for batch_id, beam_sent in enumerate(input_ids.view(-1, self._num_beams, input_ids.shape[-1])):
|
||||||
for beam_id, sent in enumerate(beam_sent):
|
for beam_id, sent in enumerate(beam_sent):
|
||||||
mask[batch_id * self._num_beams + beam_id, self._prefix_allowed_tokens_fn(batch_id, sent)] = 0
|
prefix_allowed_tokens = self._prefix_allowed_tokens_fn(batch_id, sent)
|
||||||
|
if len(prefix_allowed_tokens) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"`prefix_allowed_tokens_fn` returned an empty list for batch ID {batch_id}."
|
||||||
|
f"This means that the constraint is unsatisfiable. Please check your implementation"
|
||||||
|
f"of `prefix_allowed_tokens_fn` "
|
||||||
|
)
|
||||||
|
mask[batch_id * self._num_beams + beam_id, prefix_allowed_tokens] = 0
|
||||||
|
|
||||||
return scores + mask
|
return scores + mask
|
||||||
|
|
||||||
|
|||||||
@@ -610,6 +610,13 @@ class LogitsProcessorTest(unittest.TestCase):
|
|||||||
torch.isinf(filtered_scores).tolist(), [[False, False, True, True, True], [True, True, False, False, True]]
|
torch.isinf(filtered_scores).tolist(), [[False, False, True, True, True], [True, True, False, False, True]]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def empty_prefix_allowed_tokens_fn(batch_id, inputs_ids):
|
||||||
|
return []
|
||||||
|
|
||||||
|
prefix_constrained_logits_proc = PrefixConstrainedLogitsProcessor(empty_prefix_allowed_tokens_fn, 1)
|
||||||
|
|
||||||
|
self.assertRaises(ValueError, prefix_constrained_logits_proc, input_ids, scores.clone())
|
||||||
|
|
||||||
def test_hamming_diversity(self):
|
def test_hamming_diversity(self):
|
||||||
vocab_size = 4
|
vocab_size = 4
|
||||||
num_beams = 2
|
num_beams = 2
|
||||||
|
|||||||
Reference in New Issue
Block a user