Adding PrefixConstrainedLogitsProcessor (#8529)
* Adding PrefixConstrainedLogitsProcessor * fixing RAG and style_doc * fixing black (v20 instead of v19) * Improving doc in generation_logits_process.py * Improving docs and typing in generation_utils.py * docs improvement * adding test and fixing doc typo * fixing doc_len * isort on test * fixed test * improve docstring a bit Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -15,7 +15,7 @@
|
||||
"""RAG model implementation."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@@ -1229,6 +1229,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
||||
num_return_sequences=None,
|
||||
decoder_start_token_id=None,
|
||||
n_docs=None,
|
||||
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]] = None,
|
||||
**model_kwargs
|
||||
):
|
||||
"""
|
||||
@@ -1302,6 +1303,13 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
||||
If an encoder-decoder model starts decoding with a different token than `bos`, the id of that token.
|
||||
n_docs (:obj:`int`, `optional`, defaults to :obj:`config.n_docs`)
|
||||
Number of documents to retrieve and/or number of documents for which to generate an answer.
|
||||
prefix_allowed_tokens_fn: (:obj:`Callable[[int, torch.Tensor], List[int]]`, `optional`):
|
||||
If provided, this function constraints the beam search to allowed tokens only at each step. If not
|
||||
provided no constraint is applied. This function takes 2 arguments :obj:`inputs_ids` and the batch ID
|
||||
:obj:`batch_id`. It has to return a list with the allowed tokens for the next generation step
|
||||
conditioned on the previously generated tokens :obj:`inputs_ids` and the batch ID :obj:`batch_id`. This
|
||||
argument is useful for constrained generation conditioned on the prefix, as described in
|
||||
`Autoregressive Entity Retrieval <https://arxiv.org/abs/2010.00904>`__.
|
||||
|
||||
Return:
|
||||
:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated
|
||||
@@ -1395,6 +1403,8 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
||||
bad_words_ids=bad_words_ids,
|
||||
min_length=min_length,
|
||||
eos_token_id=eos_token_id,
|
||||
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
||||
num_beams=num_beams,
|
||||
)
|
||||
|
||||
if num_beams == 1:
|
||||
|
||||
Reference in New Issue
Block a user