Adding PrefixConstrainedLogitsProcessor (#8529)

* Adding PrefixConstrainedLogitsProcessor

* fixing RAG and style_doc

* fixing black (v20 instead of v19)

* Improving doc in generation_logits_process.py

* Improving docs and typing in generation_utils.py

* docs improvement

* adding test and fixing doc typo

* fixing doc_len

* isort on test

* fixed test

* improve docstring a bit

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Nicola De Cao
2020-11-18 16:06:25 +00:00
committed by GitHub
parent 3bc1540070
commit 2f9d49b389
4 changed files with 77 additions and 3 deletions

View File

@@ -13,8 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from abc import ABC
from typing import Iterable, List
from typing import Callable, Iterable, List
import numpy as np
import torch
@@ -372,3 +373,30 @@ class NoBadWordsLogitsProcessor(LogitsProcessor):
)
scores = scores.masked_fill(banned_mask, -float("inf"))
return scores
class PrefixConstrainedLogitsProcessor(LogitsProcessor):
r"""
:class:`transformers.LogitsProcessor` that enforces contrained generation and is useful for prefix-conditioned
constrained generation. See `Autoregressive Entity Retrieval <https://arxiv.org/abs/2010.00904>`__ for more
information.
Args:
prefix_allowed_tokens_fn: (:obj:`Callable[[int, torch.Tensor], List[int]]`):
This function constraints the beam search to allowed tokens only at each step. This function takes 2
arguments :obj:`inputs_ids` and the batch ID :obj:`batch_id`. It has to return a list with the allowed
tokens for the next generation step conditioned on the previously generated tokens :obj:`inputs_ids` and
the batch ID :obj:`batch_id`.
"""
def __init__(self, prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]], num_beams: int):
self._prefix_allowed_tokens_fn = prefix_allowed_tokens_fn
self._num_beams = num_beams
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
mask = torch.full_like(scores, -math.inf)
for batch_id, beam_sent in enumerate(input_ids.view(-1, self._num_beams, input_ids.shape[-1])):
for beam_id, sent in enumerate(beam_sent):
mask[batch_id * self._num_beams + beam_id, self._prefix_allowed_tokens_fn(batch_id, sent)] = 0
return scores + mask