Adding PrefixConstrainedLogitsProcessor (#8529)
* Adding PrefixConstrainedLogitsProcessor * fixing RAG and style_doc * fixing black (v20 instead of v19) * Improving doc in generation_logits_process.py * Improving docs and typing in generation_utils.py * docs improvement * adding test and fixing doc typo * fixing doc_len * isort on test * fixed test * improve docstring a bit Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -13,8 +13,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from abc import ABC
|
||||
from typing import Iterable, List
|
||||
from typing import Callable, Iterable, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -372,3 +373,30 @@ class NoBadWordsLogitsProcessor(LogitsProcessor):
|
||||
)
|
||||
scores = scores.masked_fill(banned_mask, -float("inf"))
|
||||
return scores
|
||||
|
||||
|
||||
class PrefixConstrainedLogitsProcessor(LogitsProcessor):
|
||||
r"""
|
||||
:class:`transformers.LogitsProcessor` that enforces contrained generation and is useful for prefix-conditioned
|
||||
constrained generation. See `Autoregressive Entity Retrieval <https://arxiv.org/abs/2010.00904>`__ for more
|
||||
information.
|
||||
|
||||
Args:
|
||||
prefix_allowed_tokens_fn: (:obj:`Callable[[int, torch.Tensor], List[int]]`):
|
||||
This function constraints the beam search to allowed tokens only at each step. This function takes 2
|
||||
arguments :obj:`inputs_ids` and the batch ID :obj:`batch_id`. It has to return a list with the allowed
|
||||
tokens for the next generation step conditioned on the previously generated tokens :obj:`inputs_ids` and
|
||||
the batch ID :obj:`batch_id`.
|
||||
"""
|
||||
|
||||
def __init__(self, prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]], num_beams: int):
|
||||
self._prefix_allowed_tokens_fn = prefix_allowed_tokens_fn
|
||||
self._num_beams = num_beams
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
mask = torch.full_like(scores, -math.inf)
|
||||
for batch_id, beam_sent in enumerate(input_ids.view(-1, self._num_beams, input_ids.shape[-1])):
|
||||
for beam_id, sent in enumerate(beam_sent):
|
||||
mask[batch_id * self._num_beams + beam_id, self._prefix_allowed_tokens_fn(batch_id, sent)] = 0
|
||||
|
||||
return scores + mask
|
||||
|
||||
Reference in New Issue
Block a user