Adding PrefixConstrainedLogitsProcessor (#8529)

* Adding PrefixConstrainedLogitsProcessor

* fixing RAG and style_doc

* fixing black (v20 instead of v19)

* Improving doc in generation_logits_process.py

* Improving docs and typing in generation_utils.py

* docs improvement

* adding test and fixing doc typo

* fixing doc_len

* isort on test

* fixed test

* improve docstring a bit

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Nicola De Cao
2020-11-18 16:06:25 +00:00
committed by GitHub
parent 3bc1540070
commit 2f9d49b389
4 changed files with 77 additions and 3 deletions

View File

@@ -31,6 +31,7 @@ if is_torch_available():
MinLengthLogitsProcessor,
NoBadWordsLogitsProcessor,
NoRepeatNGramLogitsProcessor,
PrefixConstrainedLogitsProcessor,
RepetitionPenaltyLogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
@@ -281,3 +282,23 @@ class LogitsProcessorTest(unittest.TestCase):
# input_ids should never be changed
self.assertListEqual(input_ids.tolist(), input_ids_comp.tolist())
def test_prefix_constrained_logits_processor(self):
vocab_size = 5
batch_size = 2
input_ids = torch.tensor([[0, 1, 3, 1], [0, 1, 0, 1]], device=torch_device, dtype=torch.long)
scores = self._get_uniform_logits(batch_size, vocab_size)
def prefix_allowed_tokens_fn(batch_id, inputs_ids):
return [[0, 1], [2, 3]][batch_id]
prefix_constrained_logits_proc = PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, 1)
filtered_scores = prefix_constrained_logits_proc(input_ids, scores.clone())
# batch 1: 1st, 2nd (0, 1) token are allowed
# batch 2: 3rd, 4th (2, 3) token are allowed
self.assertListEqual(
torch.isinf(filtered_scores).tolist(), [[False, False, True, True, True], [True, True, False, False, True]]
)