Adding PrefixConstrainedLogitsProcessor (#8529)

* Adding PrefixConstrainedLogitsProcessor

* fixing RAG and style_doc

* fixing black (v20 instead of v19)

* Improving doc in generation_logits_process.py

* Improving docs and typing in generation_utils.py

* docs improvement

* adding test and fixing doc typo

* fixing doc_len

* isort on test

* fixed test

* improve docstring a bit

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Nicola De Cao
2020-11-18 16:06:25 +00:00
committed by GitHub
parent 3bc1540070
commit 2f9d49b389
4 changed files with 77 additions and 3 deletions

View File

@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Iterable, List, Optional, Tuple
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
import torch
from torch.nn import functional as F
@@ -26,6 +26,7 @@ from .generation_logits_process import (
MinLengthLogitsProcessor,
NoBadWordsLogitsProcessor,
NoRepeatNGramLogitsProcessor,
PrefixConstrainedLogitsProcessor,
RepetitionPenaltyLogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
@@ -258,6 +259,8 @@ class GenerationMixin:
bad_words_ids: List[List[int]],
min_length: int,
eos_token_id: int,
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
num_beams: int,
) -> LogitsProcessorList:
"""
This class returns a :obj:`~transformers.LogitsProcessorList` list object that contains all relevant
@@ -285,6 +288,8 @@ class GenerationMixin:
processors.append(NoBadWordsLogitsProcessor(bad_words_ids, eos_token_id))
if min_length is not None and eos_token_id is not None and min_length > -1:
processors.append(MinLengthLogitsProcessor(min_length, eos_token_id))
if prefix_allowed_tokens_fn is not None:
processors.append(PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, num_beams))
return processors
@torch.no_grad()
@@ -309,6 +314,7 @@ class GenerationMixin:
num_return_sequences: Optional[int] = None,
decoder_start_token_id: Optional[int] = None,
use_cache: Optional[bool] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
**model_kwargs
) -> torch.LongTensor:
r"""
@@ -375,6 +381,13 @@ class GenerationMixin:
use_cache: (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not the model should use the past last key/values attentions (if applicable to the model) to
speed up decoding.
prefix_allowed_tokens_fn: (:obj:`Callable[[int, torch.Tensor], List[int]]`, `optional`):
If provided, this function constraints the beam search to allowed tokens only at each step. If not
provided no constraint is applied. This function takes 2 arguments :obj:`inputs_ids` and the batch ID
:obj:`batch_id`. It has to return a list with the allowed tokens for the next generation step
conditioned on the previously generated tokens :obj:`inputs_ids` and the batch ID :obj:`batch_id`. This
argument is useful for constrained generation conditioned on the prefix, as described in
`Autoregressive Entity Retrieval <https://arxiv.org/abs/2010.00904>`__.
model_kwargs:
Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If the
model is an Encoder-Decoder model, encoder specific kwargs should not be prefixed and decoder specific
@@ -494,6 +507,8 @@ class GenerationMixin:
bad_words_ids=bad_words_ids,
min_length=min_length,
eos_token_id=eos_token_id,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
num_beams=num_beams,
)
if is_greedy_gen_mode: