* add forced logits processors * delete adjust_logits method * add forced_eos_token_id argument in config * add tests for forced logits processors * update gen utils tests * add forced option to tf generate * remove adjust_logits method from tf models * update adjust_logits for marian * delete _force_token_id_to_be_generated method * style * import warnings * pass max_length to _get_logits_processor * set forced_eos_token_id to None * set forced attributes in conf utils * typo * fix rag generate * add forced_eos_token_id in rag config * remove force_bos_token_to_be_generated from BartConfig * remove _force_token_ids_generation from FSMT * nit * fix negative constant * apply suggestions from code review
569 lines
25 KiB
Python
569 lines
25 KiB
Python
# coding=utf-8
|
|
# Copyright 2020 The HuggingFace Inc. team
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import inspect
|
|
import math
|
|
from abc import ABC
|
|
from typing import Callable, Iterable, List
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
from .file_utils import add_start_docstrings
|
|
|
|
|
|
LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
|
|
Args:
|
|
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
|
|
Indices of input sequence tokens in the vocabulary.
|
|
|
|
Indices can be obtained using :class:`~transformers.BertTokenizer`. See
|
|
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
|
|
details.
|
|
|
|
`What are input IDs? <../glossary.html#input-ids>`__
|
|
scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.vocab_size)`):
|
|
Prediction scores of a language modeling head. These can be scores for each vocabulary token before SoftMax
|
|
or scores for each vocabulary token after SoftMax.
|
|
kwargs:
|
|
Additional logits processor specific kwargs.
|
|
|
|
Return:
|
|
:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.vocab_size)`: The processed prediction scores.
|
|
|
|
"""
|
|
|
|
|
|
class LogitsProcessor(ABC):
    """Abstract base class for all logit processors that can be applied during generation.

    Concrete processors subclass this and override ``__call__``; calling the base implementation raises.
    """

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Torch method for processing logits."""
        message = f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
        raise NotImplementedError(message)
|
|
|
|
|
|
class LogitsWarper(ABC):
    """Abstract base class for all logit warpers that can be applied during generation with multinomial sampling.

    Concrete warpers subclass this and override ``__call__``; calling the base implementation raises.
    """

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Torch method for warping logits."""
        message = f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
        raise NotImplementedError(message)
|
|
|
|
|
|
class LogitsProcessorList(list):
    """
    This class can be used to create a list of :class:`~transformers.LogitsProcessor` or
    :class:`~transformers.LogitsWarper` to subsequently process a :obj:`scores` input tensor. This class inherits from
    list and adds a specific `__call__` method to apply each :class:`~transformers.LogitsProcessor` or
    :class:`~transformers.LogitsWarper` to the inputs.

    Processors are applied in list order, each one receiving the scores returned by the previous one. A processor
    whose ``__call__`` declares parameters beyond ``input_ids`` and ``scores`` is called with all of the ``kwargs``
    given to this method. Each of those extra parameters must be present in ``kwargs``, otherwise a ``ValueError``
    is raised.
    """

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
        for processor in self:
            # Signature of the bound method, so ``self`` is not counted: the first two are input_ids and scores.
            function_args = inspect.signature(processor.__call__).parameters
            if len(function_args) > 2:
                # Explicit exception instead of ``assert`` so the check is not stripped under ``python -O``.
                if not all(arg in kwargs for arg in list(function_args.keys())[2:]):
                    raise ValueError(
                        f"Make sure that all the required parameters: {list(function_args.keys())} for "
                        f"{processor.__class__} are passed to the logits processor."
                    )
                scores = processor(input_ids, scores, **kwargs)
            else:
                scores = processor(input_ids, scores)
        return scores
|
|
|
|
|
|
class MinLengthLogitsProcessor(LogitsProcessor):
    r"""
    :class:`transformers.LogitsProcessor` enforcing a min-length by setting EOS probability to 0.

    While the current sequence length (last dimension of ``input_ids``) is below :obj:`min_length`, the score of
    :obj:`eos_token_id` is set to ``-inf`` in every row. ``scores`` is modified in place.

    Args:
        min_length (:obj:`int`):
            The minimum length below which the score of :obj:`eos_token_id` is set to :obj:`-float("Inf")`.
        eos_token_id (:obj:`int`):
            The id of the `end-of-sequence` token.
    """

    def __init__(self, min_length: int, eos_token_id: int):
        if not (isinstance(min_length, int) and min_length >= 0):
            raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}")

        if not (isinstance(eos_token_id, int) and eos_token_id >= 0):
            raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}")

        self.min_length, self.eos_token_id = min_length, eos_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Long enough already: EOS is allowed, nothing to do.
        if input_ids.shape[-1] >= self.min_length:
            return scores
        scores[:, self.eos_token_id] = -float("inf")
        return scores
|
|
|
|
|
|
class TemperatureLogitsWarper(LogitsWarper):
    r"""
    :class:`transformers.LogitsWarper` for temperature (exponential scaling output probability distribution).

    Every score is divided by :obj:`temperature`. If the result is later passed through a softmax, values below 1
    sharpen the distribution and values above 1 flatten it. The input tensor is not modified.

    Args:
        temperature (:obj:`float`):
            The value used to module the logits distribution. Must be a strictly positive ``float``.
    """

    def __init__(self, temperature: float):
        temperature_is_valid = isinstance(temperature, float) and temperature > 0
        if not temperature_is_valid:
            raise ValueError(f"`temperature` has to be a strictly positive float, but is {temperature}")

        self.temperature = temperature

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        # Out-of-place division; a new tensor is returned.
        return scores / self.temperature
|
|
|
|
|
|
class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
    r"""
    :class:`transformers.LogitsProcessor` that penalizes tokens which already appear in :obj:`input_ids`.

    For every token id present in a row of ``input_ids``, the score of that token in the same row is divided by
    :obj:`penalty` when it is non-negative and multiplied by :obj:`penalty` when it is negative. With ``penalty > 1``
    this lowers the score of previously seen tokens. ``scores`` is modified in place.

    Args:
        penalty (:obj:`float`):
            The parameter for repetition penalty. 1.0 means no penalty. Must be a strictly positive ``float``. See
            `this paper <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details.
    """

    def __init__(self, penalty: float):
        if not isinstance(penalty, float) or not (penalty > 0):
            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")

        self.penalty = penalty

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Current scores of every token already present in each row.
        score = torch.gather(scores, 1, input_ids)

        # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
        score = torch.where(score < 0, score * self.penalty, score / self.penalty)

        # Write the penalized values back to the same positions (in place).
        scores.scatter_(1, input_ids, score)
        return scores
|
|
|
|
|
|
class TopPLogitsWarper(LogitsWarper):
    """
    :class:`transformers.LogitsWarper` that performs top-p (nucleus) filtering: only the most probable tokens whose
    probabilities add up to :obj:`top_p` or higher are kept, and all other scores are replaced by
    :obj:`filter_value`.

    Args:
        top_p (:obj:`float`):
            If set to < 1, only the most probable tokens with probabilities that add up to :obj:`top_p` or higher are
            kept for generation. Must be a ``float`` in ``[0, 1]``.
        filter_value (:obj:`float`, `optional`, defaults to :obj:`-float("Inf")`):
            All filtered values will be set to this float value.
        min_tokens_to_keep (:obj:`int`, `optional`, defaults to 1):
            Minimum number of tokens that cannot be filtered.
    """

    def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        if not isinstance(top_p, float) or (top_p < 0 or top_p > 1.0):
            # Message matches the check above: both 0 and 1 are accepted.
            raise ValueError(f"`top_p` has to be a float >= 0 and <= 1, but is {top_p}")

        self.top_p = top_p
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        sorted_logits, sorted_indices = torch.sort(scores, descending=True)
        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

        # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
        sorted_indices_to_remove = cumulative_probs > self.top_p
        if self.min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
            sorted_indices_to_remove[..., : self.min_tokens_to_keep - 1] = 0
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # scatter sorted tensors to original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        scores = scores.masked_fill(indices_to_remove, self.filter_value)
        return scores
|
|
|
|
|
|
class TopKLogitsWarper(LogitsWarper):
    r"""
    :class:`transformers.LogitsWarper` that performs top-k, i.e. restricting to the k highest probability elements.

    Every score strictly smaller than the k-th largest score of its row is replaced by :obj:`filter_value`. The
    effective k is ``max(top_k, min_tokens_to_keep)`` capped at the vocabulary size. The input tensor is not modified.

    Args:
        top_k (:obj:`int`):
            The number of highest probability vocabulary tokens to keep for top-k-filtering.
        filter_value (:obj:`float`, `optional`, defaults to :obj:`-float("Inf")`):
            All filtered values will be set to this float value.
        min_tokens_to_keep (:obj:`int`, `optional`, defaults to 1):
            Minimum number of tokens that cannot be filtered.
    """

    def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        if not (isinstance(top_k, int) and top_k > 0):
            raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")

        self.top_k = top_k
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        vocab_size = scores.size(-1)
        k = min(max(self.top_k, self.min_tokens_to_keep), vocab_size)  # Safety check
        # k-th largest score per row, kept as a trailing singleton dim so it broadcasts against ``scores``.
        kth_largest = torch.topk(scores, k).values[..., -1, None]
        return scores.masked_fill(scores < kth_largest, self.filter_value)
|
|
|
|
|
|
def _get_ngrams(ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int):
|
|
generated_ngrams = [{} for _ in range(num_hypos)]
|
|
for idx in range(num_hypos):
|
|
gen_tokens = prev_input_ids[idx].tolist()
|
|
generated_ngram = generated_ngrams[idx]
|
|
for ngram in zip(*[gen_tokens[i:] for i in range(ngram_size)]):
|
|
prev_ngram_tuple = tuple(ngram[:-1])
|
|
generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
|
|
return generated_ngrams
|
|
|
|
|
|
def _get_generated_ngrams(banned_ngrams, prev_input_ids, ngram_size, cur_len):
|
|
# Before decoding the next token, prevent decoding of ngrams that have already appeared
|
|
start_idx = cur_len + 1 - ngram_size
|
|
ngram_idx = tuple(prev_input_ids[start_idx:cur_len].tolist())
|
|
return banned_ngrams.get(ngram_idx, [])
|
|
|
|
|
|
def _calc_banned_ngram_tokens(
    ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int, cur_len: int
) -> List[Iterable[int]]:
    """
    Copied from fairseq for no_repeat_ngram in beam_search.

    Returns one list per hypothesis with the tokens that would repeat an n-gram already present in that
    hypothesis. All lists are empty while ``cur_len + 1 < ngram_size``.
    """
    if cur_len + 1 >= ngram_size:
        ngrams_per_hypo = _get_ngrams(ngram_size, prev_input_ids, num_hypos)
        return [
            _get_generated_ngrams(hypo_ngrams, prev_input_ids[hypo_idx], ngram_size, cur_len)
            for hypo_idx, hypo_ngrams in enumerate(ngrams_per_hypo)
        ]
    # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
    return [[] for _ in range(num_hypos)]
|
|
|
|
|
|
class NoRepeatNGramLogitsProcessor(LogitsProcessor):
    r"""
    :class:`transformers.LogitsProcessor` that enforces no repetition of n-grams. See `Fairseq
    <https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345>`__.

    For each row, any token that would complete an n-gram already present in that row's ``input_ids`` gets a score
    of ``-inf``. ``scores`` is modified in place.

    Args:
        ngram_size (:obj:`int`):
            All ngrams of size :obj:`ngram_size` can only occur once.
    """

    def __init__(self, ngram_size: int):
        if not (isinstance(ngram_size, int) and ngram_size > 0):
            raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
        self.ngram_size = ngram_size

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        banned_per_row = _calc_banned_ngram_tokens(
            self.ngram_size, input_ids, scores.shape[0], input_ids.shape[-1]
        )
        for row, row_banned in enumerate(banned_per_row):
            scores[row, row_banned] = -float("inf")
        return scores
|
|
|
|
|
|
class EncoderNoRepeatNGramLogitsProcessor(LogitsProcessor):
    r"""
    :class:`transformers.LogitsProcessor` that enforces no repetition of encoder input ids n-grams for the decoder ids.
    See `ParlAI <https://github.com/facebookresearch/ParlAI/blob/master/parlai/core/torch_generator_agent.py#L1350>`__.

    Hypothesis row ``i`` of ``scores`` is matched against encoder input ``i // num_beams``, where ``num_beams`` is
    ``scores.shape[0] // batch_size``. Any token that would complete an n-gram of that encoder input gets a score of
    ``-inf``. ``scores`` is modified in place.

    Args:
        encoder_ngram_size (:obj:`int`):
            All ngrams of size :obj:`encoder_ngram_size` can only occur within the encoder input ids.
        encoder_input_ids (:obj:`torch.LongTensor`):
            The encoder_input_ids that should not be repeated within the decoder ids. A 1-D tensor is treated as a
            batch of size one.
    """

    def __init__(self, encoder_ngram_size: int, encoder_input_ids: torch.LongTensor):
        if not isinstance(encoder_ngram_size, int) or encoder_ngram_size <= 0:
            raise ValueError(
                f"`encoder_ngram_size` has to be a strictly positive integer, but is {encoder_ngram_size}"
            )
        self.ngram_size = encoder_ngram_size
        if len(encoder_input_ids.shape) == 1:
            # Promote a single sequence to a batch of one.
            encoder_input_ids = encoder_input_ids.unsqueeze(0)
        self.batch_size = encoder_input_ids.shape[0]
        # n-grams are indexed once from the encoder side and reused at every decoding step.
        self.generated_ngrams = _get_ngrams(encoder_ngram_size, encoder_input_ids, self.batch_size)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # B x num_beams
        num_hypos = scores.shape[0]
        num_beams = num_hypos // self.batch_size
        cur_len = input_ids.shape[-1]
        banned_batch_tokens = [
            _get_generated_ngrams(
                self.generated_ngrams[hypo_idx // num_beams], input_ids[hypo_idx], self.ngram_size, cur_len
            )
            for hypo_idx in range(num_hypos)
        ]

        for i, banned_tokens in enumerate(banned_batch_tokens):
            scores[i, banned_tokens] = -float("inf")

        return scores
|
|
|
|
|
|
class NoBadWordsLogitsProcessor(LogitsProcessor):
    """
    :class:`transformers.LogitsProcessor` that enforces that specified sequences will never be sampled.

    For every row of ``input_ids`` and every banned sequence, if the row currently ends with all but the last token
    of that sequence, the score of the sequence's last token is set to ``-inf``. Single-token sequences are therefore
    always banned. A sequence equal to ``[eos_token_id]`` is ignored.

    Args:
        bad_words_ids (:obj:`List[List[int]]`):
            List of list of token ids that are not allowed to be generated. In order to get the tokens of the words
            that should not appear in the generated text, use :obj:`tokenizer(bad_word,
            add_prefix_space=True).input_ids`.
        eos_token_id (:obj:`int`):
            The id of the `end-of-sequence` token.
    """

    def __init__(self, bad_words_ids: Iterable[Iterable[int]], eos_token_id: int):

        if not isinstance(bad_words_ids, list) or len(bad_words_ids) == 0:
            raise ValueError(f"`bad_words_ids` has to be a non-empty list, but is {bad_words_ids}.")
        if any(not isinstance(bad_word_ids, list) for bad_word_ids in bad_words_ids):
            raise ValueError(f"`bad_words_ids` has to be a list of lists, but is {bad_words_ids}.")
        if any(
            any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in bad_word_ids)
            for bad_word_ids in bad_words_ids
        ):
            raise ValueError(
                f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}."
            )
        # Raised explicitly (not ``assert``): under ``python -O`` an empty sequence would otherwise slip through and
        # fail later with an IndexError during generation.
        if any(len(bad_word_ids) == 0 for bad_word_ids in bad_words_ids):
            raise ValueError(f"Banned words token sequences {bad_words_ids} cannot have an empty list")

        self.bad_words_ids = [bad_token_seq for bad_token_seq in bad_words_ids if bad_token_seq != [eos_token_id]]

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        banned_tokens = self._calc_banned_bad_words_ids(input_ids)
        scores = self._set_scores_to_inf_for_banned_tokens(scores, banned_tokens)

        return scores

    def _tokens_match(self, prev_tokens: torch.LongTensor, tokens: List[int]) -> bool:
        """Return True if ``prev_tokens`` ends with ``tokens`` (always True for an empty ``tokens``)."""
        if len(tokens) == 0:
            # if bad word tokens is just one token always ban it
            return True
        if len(tokens) > len(prev_tokens):
            # if bad word tokens are longer then prev input_ids they can't be equal
            return False
        return prev_tokens[-len(tokens) :].tolist() == tokens

    def _calc_banned_bad_words_ids(self, prev_input_ids: Iterable[int]) -> List[List[int]]:
        """For each row of ``prev_input_ids``, list the last tokens of banned sequences whose prefix the row ends with."""
        return [
            [
                banned_token_seq[-1]
                for banned_token_seq in self.bad_words_ids
                if self._tokens_match(prev_input_ids_slice, banned_token_seq[:-1])
            ]
            for prev_input_ids_slice in prev_input_ids
        ]

    def _set_scores_to_inf_for_banned_tokens(
        self, scores: torch.Tensor, banned_tokens: List[List[int]]
    ) -> torch.Tensor:
        """
        Return ``scores`` with the banned token positions set to ``-inf``. ``masked_fill`` is out-of-place, so the
        input tensor is not modified; when nothing is banned the input tensor itself is returned.

        Args:
            scores: logits distribution of shape (batch size, vocabulary size)
            banned_tokens: list of length (batch_size); entry ``i`` lists the vocabulary ids to ban in row ``i``.
        """
        batch_indices = [idx for idx, row_tokens in enumerate(banned_tokens) for _ in row_tokens]
        token_indices = [token for row_tokens in banned_tokens for token in row_tokens]
        if not token_indices:
            return scores

        # Build the boolean mask directly on the scores' device instead of going through the deprecated legacy
        # ``torch.sparse.LongTensor`` constructor. Duplicate (row, token) pairs simply set the same entry twice.
        banned_mask = torch.zeros_like(scores, dtype=torch.bool)
        banned_mask[
            torch.tensor(batch_indices, dtype=torch.long, device=scores.device),
            torch.tensor(token_indices, dtype=torch.long, device=scores.device),
        ] = True
        return scores.masked_fill(banned_mask, -float("inf"))
|
|
|
|
|
|
class PrefixConstrainedLogitsProcessor(LogitsProcessor):
    r"""
    :class:`transformers.LogitsProcessor` that enforces constrained generation and is useful for prefix-conditioned
    constrained generation. See `Autoregressive Entity Retrieval <https://arxiv.org/abs/2010.00904>`__ for more
    information.

    ``input_ids`` rows are viewed as ``(batch, num_beams, seq_len)``. For each row, only the tokens returned by
    :obj:`prefix_allowed_tokens_fn` keep their score, and every other token gets ``-inf``. A new tensor is returned;
    ``scores`` is not modified.

    Args:
        prefix_allowed_tokens_fn: (:obj:`Callable[[int, torch.Tensor], List[int]]`):
            This function constraints the beam search to allowed tokens only at each step. This function takes 2
            arguments, the batch ID :obj:`batch_id` and :obj:`inputs_ids`, in that order. It has to return a list
            with the allowed tokens for the next generation step conditioned on the batch ID :obj:`batch_id` and the
            previously generated tokens :obj:`inputs_ids`.
        num_beams (:obj:`int`):
            Number of beams per batch element; consecutive groups of :obj:`num_beams` rows share one batch ID.
    """

    def __init__(self, prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]], num_beams: int):
        self._prefix_allowed_tokens_fn = prefix_allowed_tokens_fn
        self._num_beams = num_beams

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Start from "everything forbidden" and open up the allowed tokens per hypothesis.
        mask = torch.full_like(scores, -math.inf)
        for batch_id, beam_sent in enumerate(input_ids.view(-1, self._num_beams, input_ids.shape[-1])):
            for beam_id, sent in enumerate(beam_sent):
                mask[batch_id * self._num_beams + beam_id, self._prefix_allowed_tokens_fn(batch_id, sent)] = 0

        return scores + mask
|
|
|
|
|
|
class HammingDiversityLogitsProcessor(LogitsProcessor):
    r"""
    :class:`transformers.LogitsProcessor` that enforces diverse beam search. Note that this logits processor is only
    effective for :meth:`transformers.PreTrainedModel.group_beam_search`. See `Diverse Beam Search: Decoding Diverse
    Solutions from Neural Sequence Models <https://arxiv.org/pdf/1610.02424.pdf>`__ for more details.

    For group ``beam_group_idx > 0``, every token is penalized by ``diversity_penalty`` times the number of beams of
    the earlier groups (of the same batch element) that chose it at the current step. Group 0 is left untouched.
    ``scores`` is modified in place.

    Args:
        diversity_penalty (:obj:`float`):
            This value is subtracted from a beam's score if it generates a token same as any beam from other group at a
            particular time. Note that :obj:`diversity_penalty` is only effective if ``group beam search`` is enabled.
        num_beams (:obj:`int`):
            Number of beams used for group beam search. See `this paper <https://arxiv.org/pdf/1610.02424.pdf>`__ for
            more details.
        num_beam_groups (:obj:`int`):
            Number of groups to divide :obj:`num_beams` into in order to ensure diversity among different groups of
            beams. See `this paper <https://arxiv.org/pdf/1610.02424.pdf>`__ for more details.
    """

    def __init__(self, diversity_penalty: float, num_beams: int, num_beam_groups: int):
        if not (isinstance(diversity_penalty, float) and diversity_penalty > 0.0):
            raise ValueError("`diversity_penalty` should be a float strictly larger than 0.")
        if not (isinstance(num_beams, int) and num_beams >= 2):
            raise ValueError("`num_beams` should be an integer strictly larger than 1.")
        if not (isinstance(num_beam_groups, int) and num_beam_groups >= 2):
            raise ValueError("`num_beam_groups` should be an integer strictly larger than 1.")
        if num_beam_groups > num_beams:
            raise ValueError("`beam_groups` has to be smaller or equal to `num_beams`.")

        self._diversity_penalty = diversity_penalty
        self._num_beams = num_beams
        # Beams per group (integer division).
        self._num_sub_beams = num_beams // num_beam_groups

    def __call__(
        self,
        input_ids: torch.LongTensor,
        scores: torch.FloatTensor,
        current_tokens: torch.LongTensor,
        beam_group_idx: int,
    ) -> torch.FloatTensor:
        # hamming diversity: penalise using same token in current group which was used in previous groups at
        # the same time step
        group_start_idx = beam_group_idx * self._num_sub_beams
        if group_start_idx == 0:
            # The first group has no earlier groups to diverge from.
            return scores

        batch_size = current_tokens.shape[0] // self._num_beams
        group_end_idx = min(group_start_idx + self._num_sub_beams, self._num_beams)
        group_size = group_end_idx - group_start_idx
        vocab_size = scores.shape[-1]

        for batch_idx in range(batch_size):
            beam_offset = batch_idx * self._num_beams
            # predicted tokens of last time step of previous groups
            previous_group_tokens = current_tokens[beam_offset : beam_offset + group_start_idx]
            token_frequency = torch.bincount(previous_group_tokens, minlength=vocab_size).to(scores.device)
            row_start = batch_idx * group_size
            scores[row_start : row_start + group_size] -= self._diversity_penalty * token_frequency

        return scores
|
|
|
|
|
|
class ForcedBOSTokenLogitsProcessor(LogitsProcessor):
    r"""
    :class:`~transformers.LogitsProcessor` that enforces the specified token as the first generated token.

    When ``input_ids`` has length 1, every score is set to ``-inf`` except that of :obj:`bos_token_id`, which is set
    to 0. At any other length ``scores`` is returned unchanged. ``scores`` is modified in place.

    Args:
        bos_token_id (:obj:`int`):
            The id of the token to force as the first generated token.
    """

    def __init__(self, bos_token_id: int):
        self.bos_token_id = bos_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        cur_len = input_ids.shape[-1]
        if cur_len == 1:
            # Forbid everything, then re-enable the forced token. This avoids building a vocab-sized Python list of
            # "all other" indices on every call.
            scores.fill_(-float("inf"))
            scores[:, self.bos_token_id] = 0
        return scores
|
|
|
|
|
|
class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
    r"""
    :class:`~transformers.LogitsProcessor` that enforces the specified token as the last generated token when
    :obj:`max_length` is reached.

    When ``input_ids`` has length ``max_length - 1`` (the next token is the last one), every score is set to ``-inf``
    except that of :obj:`eos_token_id`, which is set to 0. At any other length ``scores`` is returned unchanged.
    ``scores`` is modified in place.

    Args:
        max_length (:obj:`int`):
            The maximum length of the sequence to be generated.
        eos_token_id (:obj:`int`):
            The id of the token to force as the last generated token when :obj:`max_length` is reached.
    """

    def __init__(self, max_length: int, eos_token_id: int):
        self.max_length = max_length
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        cur_len = input_ids.shape[-1]
        if cur_len == self.max_length - 1:
            # Forbid everything, then re-enable the forced token. This avoids building a vocab-sized Python list of
            # "all other" indices on every call.
            scores.fill_(-float("inf"))
            scores[:, self.eos_token_id] = 0
        return scores
|