Adding PrefixConstrainedLogitsProcessor (#8529)
* Adding PrefixConstrainedLogitsProcessor
* fixing RAG and style_doc
* fixing black (v20 instead of v19)
* Improving doc in generation_logits_process.py
* Improving docs and typing in generation_utils.py
* docs improvement
* adding test and fixing doc typo
* fixing doc_len
* isort on test
* fixed test
* improve docstring a bit

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -31,6 +31,7 @@ if is_torch_available():
|
||||
MinLengthLogitsProcessor,
|
||||
NoBadWordsLogitsProcessor,
|
||||
NoRepeatNGramLogitsProcessor,
|
||||
PrefixConstrainedLogitsProcessor,
|
||||
RepetitionPenaltyLogitsProcessor,
|
||||
TemperatureLogitsWarper,
|
||||
TopKLogitsWarper,
|
||||
@@ -281,3 +282,23 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
|
||||
# input_ids should never be changed
|
||||
self.assertListEqual(input_ids.tolist(), input_ids_comp.tolist())
|
||||
|
||||
def test_prefix_constrained_logits_processor(self):
    """Check which scores PrefixConstrainedLogitsProcessor turns infinite.

    The prefix function used here allows token ids (0, 1) for batch index 0
    and (2, 3) for batch index 1; every other position of the processed
    scores is expected to be infinite.
    """
    num_tokens = 5
    num_sequences = 2

    input_ids = torch.tensor([[0, 1, 3, 1], [0, 1, 0, 1]], device=torch_device, dtype=torch.long)
    initial_scores = self._get_uniform_logits(num_sequences, num_tokens)

    def prefix_allowed_tokens_fn(batch_id, inputs_ids):
        # The allowed ids depend only on the batch index; the prefix
        # (second argument) is deliberately ignored in this test.
        allowed_by_batch = [[0, 1], [2, 3]]
        return allowed_by_batch[batch_id]

    # NOTE(review): the second constructor argument (1) is presumably the
    # number of beams -- confirm against the processor's signature.
    processor = PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, 1)

    # Pass a clone so the processor never receives the original scores tensor.
    processed_scores = processor(input_ids, initial_scores.clone())

    # Row 0: only token ids 0 and 1 stay finite.
    # Row 1: only token ids 2 and 3 stay finite.
    expected_infinite = [
        [False, False, True, True, True],
        [True, True, False, False, True],
    ]
    self.assertListEqual(torch.isinf(processed_scores).tolist(), expected_infinite)
|
||||
|
||||
Reference in New Issue
Block a user