Refactoring the generate() function (#6949)
* first draft * show design proposition for new generate method * up * make better readable * make first version * gpt2 tests pass * make beam search for gpt2 work * add first encoder-decoder code * delete typo * make t5 work * save indermediate * make bart work with beam search * finish beam search bart / t5 * add default kwargs * make more tests pass * fix no bad words sampler * some fixes and tests for all distribution processors * fix test * fix rag slow tests * merge to master * add nograd to generate * make all slow tests pass * speed up generate * fix edge case bug * small fix * correct typo * add type hints and docstrings * fix typos in tests * add beam search tests * add tests for beam scorer * fix test rag * finish beam search tests * move generation tests in seperate file * fix generation tests * more tests * add aggressive generation tests * fix tests * add gpt2 sample test * add more docstring * add more docs * finish doc strings * apply some more of sylvains and sams comments * fix some typos * make fix copies * apply lysandres and sylvains comments * final corrections on examples * small fix for reformer
This commit is contained in:
committed by
GitHub
parent
b63beb743c
commit
a1bbcf3f6c
374
src/transformers/generation_logits_process.py
Normal file
374
src/transformers/generation_logits_process.py
Normal file
@@ -0,0 +1,374 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Inc. team
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from abc import ABC
|
||||
from typing import Iterable, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
from .file_utils import add_start_docstrings
|
||||
|
||||
|
||||
LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary.
|
||||
|
||||
Indices can be obtained using :class:`~transformers.BertTokenizer`. See
|
||||
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
|
||||
details.
|
||||
|
||||
`What are input IDs? <../glossary.html#input-ids>`__
|
||||
scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.vocab_size)`):
|
||||
Prediction scores of a language modeling head. These can be scores for each vocabulary token before SoftMax
|
||||
or scores for each vocabulary token after SoftMax.
|
||||
|
||||
Return:
|
||||
:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.vocab_size)`: The processed prediction scores.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class LogitsProcessor(ABC):
|
||||
"""Abstract base class for all logit processors that can be applied during generation."""
|
||||
|
||||
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
"""Torch method for processing logits."""
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
|
||||
)
|
||||
|
||||
|
||||
class LogitsWarper(ABC):
|
||||
"""Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""
|
||||
|
||||
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
"""Torch method for warping logits."""
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
|
||||
)
|
||||
|
||||
|
||||
class LogitsProcessorList(list):
|
||||
"""
|
||||
This class can be used to create a list of :class:`~transformers.LogitsProcessor` or
|
||||
:class:`~transformers.LogitsWarper` to subsequently process a :obj:`scores` input tensor. This class inherits from
|
||||
list and adds a specific `__call__` method to apply each :class:`~transformers.LogitsProcessor` or
|
||||
:class:`~transformers.LogitsProcessor` to the inputs.
|
||||
"""
|
||||
|
||||
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
for processor in self:
|
||||
scores = processor(input_ids, scores)
|
||||
return scores
|
||||
|
||||
|
||||
class MinLengthLogitsProcessor(LogitsProcessor):
|
||||
r"""
|
||||
:class:`transformers.LogitsProcessor` enforcing a min-length by setting EOS probability to 0.
|
||||
|
||||
Args:
|
||||
min_length (:obj:`int`):
|
||||
The minimum length below which the score of :obj:`eos_token_id` is set to :obj:`-float("Inf")`.
|
||||
eos_token_id (:obj:`int`):
|
||||
The id of the `end-of-sequence` token.
|
||||
"""
|
||||
|
||||
def __init__(self, min_length: int, eos_token_id: int):
|
||||
if not isinstance(min_length, int) or min_length < 0:
|
||||
raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}")
|
||||
|
||||
if not isinstance(eos_token_id, int) or eos_token_id < 0:
|
||||
raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}")
|
||||
|
||||
self.min_length = min_length
|
||||
self.eos_token_id = eos_token_id
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
cur_len = input_ids.shape[-1]
|
||||
if cur_len < self.min_length:
|
||||
scores[:, self.eos_token_id] = -float("inf")
|
||||
return scores
|
||||
|
||||
|
||||
class TemperatureLogitsWarper(LogitsWarper):
|
||||
r"""
|
||||
:class:`transformers.LogitsWarper` for temperature (exponential scaling output probability distribution).
|
||||
|
||||
Args:
|
||||
temperature (:obj:`float`):
|
||||
The value used to module the logits distribution.
|
||||
"""
|
||||
|
||||
def __init__(self, temperature: float):
|
||||
if not isinstance(temperature, float) or not (temperature > 0):
|
||||
raise ValueError(f"`temperature` has to be a strictly positive float, but is {temperature}")
|
||||
|
||||
self.temperature = temperature
|
||||
|
||||
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
|
||||
scores = scores / self.temperature
|
||||
return scores
|
||||
|
||||
|
||||
class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
|
||||
r"""
|
||||
:class:`transformers.LogitsProcessor` enforcing an exponential penalty on repeated sequences.
|
||||
|
||||
Args:
|
||||
repetition_penalty (:obj:`float`):
|
||||
The parameter for repetition penalty. 1.0 means no penalty. See `this paper
|
||||
<https://arxiv.org/pdf/1909.05858.pdf>`__ for more details.
|
||||
"""
|
||||
|
||||
def __init__(self, penalty: float):
|
||||
if not isinstance(penalty, float) or not (penalty > 0):
|
||||
raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
|
||||
|
||||
self.penalty = penalty
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
for i in range(scores.shape[0]):
|
||||
for previous_token in set(input_ids[i].tolist()):
|
||||
# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
|
||||
if scores[i, previous_token] < 0:
|
||||
scores[i, previous_token] *= self.penalty
|
||||
else:
|
||||
scores[i, previous_token] /= self.penalty
|
||||
return scores
|
||||
|
||||
|
||||
class TopPLogitsWarper(LogitsWarper):
|
||||
"""
|
||||
:class:`transformers.LogitsWarper` that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <=
|
||||
prob_cut_off.
|
||||
|
||||
Args:
|
||||
top_p (:obj:`float`):
|
||||
If set to < 1, only the most probable tokens with probabilities that add up to :obj:`top_p` or higher are
|
||||
kept for generation.
|
||||
filter_value (:obj:`float`, `optional`, defaults to :obj:`-float("Inf")`):
|
||||
All filtered values will be set to this float value.
|
||||
min_tokens_to_keep (:obj:`int`, `optional`, defaults to 1):
|
||||
Minimum number of tokens that cannot be filtered.
|
||||
"""
|
||||
|
||||
def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
||||
if not isinstance(top_p, float) or (top_p < 0 or top_p > 1.0):
|
||||
raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
|
||||
|
||||
self.top_p = top_p
|
||||
self.filter_value = filter_value
|
||||
self.min_tokens_to_keep = min_tokens_to_keep
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
sorted_logits, sorted_indices = torch.sort(scores, descending=True)
|
||||
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
||||
|
||||
# Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
|
||||
sorted_indices_to_remove = cumulative_probs > self.top_p
|
||||
if self.min_tokens_to_keep > 1:
|
||||
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
|
||||
sorted_indices_to_remove[..., : self.min_tokens_to_keep - 1] = 0
|
||||
# Shift the indices to the right to keep also the first token above the threshold
|
||||
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
||||
sorted_indices_to_remove[..., 0] = 0
|
||||
|
||||
# scatter sorted tensors to original indexing
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
||||
scores[indices_to_remove] = self.filter_value
|
||||
return scores
|
||||
|
||||
|
||||
class TopKLogitsWarper(LogitsWarper):
|
||||
r"""
|
||||
:class:`transformers.LogitsWarper` that performs top-k, i.e. restricting to the k highest probability elements.
|
||||
|
||||
Args:
|
||||
top_k (:obj:`int`):
|
||||
The number of highest probability vocabulary tokens to keep for top-k-filtering.
|
||||
filter_value (:obj:`float`, `optional`, defaults to :obj:`-float("Inf")`):
|
||||
All filtered values will be set to this float value.
|
||||
min_tokens_to_keep (:obj:`int`, `optional`, defaults to 1):
|
||||
Minimum number of tokens that cannot be filtered.
|
||||
"""
|
||||
|
||||
def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
||||
if not isinstance(top_k, int) or top_k <= 0:
|
||||
raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")
|
||||
|
||||
self.top_k = top_k
|
||||
self.filter_value = filter_value
|
||||
self.min_tokens_to_keep = min_tokens_to_keep
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
top_k = min(max(self.top_k, self.min_tokens_to_keep), scores.size(-1)) # Safety check
|
||||
# Remove all tokens with a probability less than the last token of the top-k
|
||||
indices_to_remove = scores < torch.topk(scores, top_k)[0][..., -1, None]
|
||||
scores[indices_to_remove] = self.filter_value
|
||||
return scores
|
||||
|
||||
|
||||
class NoRepeatNGramLogitsProcessor(LogitsProcessor):
|
||||
r"""
|
||||
:class:`transformers.LogitsProcessor` that enforces no repetition of n-grams. See `Fairseq
|
||||
<https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345>`__.
|
||||
|
||||
Args:
|
||||
ngram_size (:obj:`int`):
|
||||
All ngrams of size :obj:`ngram_size` can only occur once.
|
||||
"""
|
||||
|
||||
def __init__(self, ngram_size: int):
|
||||
if not isinstance(ngram_size, int) or ngram_size <= 0:
|
||||
raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
|
||||
self.ngram_size = ngram_size
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
num_batch_hypotheses = scores.shape[0]
|
||||
cur_len = input_ids.shape[-1]
|
||||
banned_batch_tokens = self._calc_banned_ngram_tokens(input_ids, num_batch_hypotheses, cur_len)
|
||||
|
||||
for i, banned_tokens in enumerate(banned_batch_tokens):
|
||||
scores[i, banned_tokens] = -float("inf")
|
||||
|
||||
return scores
|
||||
|
||||
def _calc_banned_ngram_tokens(
|
||||
self, prev_input_ids: torch.Tensor, num_hypos: int, cur_len: int
|
||||
) -> List[Iterable[int]]:
|
||||
"""Copied from fairseq for no_repeat_ngram in beam_search"""
|
||||
if cur_len + 1 < self.ngram_size:
|
||||
# return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
|
||||
return [[] for _ in range(num_hypos)]
|
||||
generated_ngrams = [{} for _ in range(num_hypos)]
|
||||
for idx in range(num_hypos):
|
||||
gen_tokens = prev_input_ids[idx].tolist()
|
||||
generated_ngram = generated_ngrams[idx]
|
||||
for ngram in zip(*[gen_tokens[i:] for i in range(self.ngram_size)]):
|
||||
prev_ngram_tuple = tuple(ngram[:-1])
|
||||
generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
|
||||
|
||||
def _get_generated_ngrams(hypo_idx):
|
||||
# Before decoding the next token, prevent decoding of ngrams that have already appeared
|
||||
start_idx = cur_len + 1 - self.ngram_size
|
||||
ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist())
|
||||
return generated_ngrams[hypo_idx].get(ngram_idx, [])
|
||||
|
||||
banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
|
||||
return banned_tokens
|
||||
|
||||
|
||||
class NoBadWordsLogitsProcessor(LogitsProcessor):
|
||||
"""
|
||||
:class:`transformers.LogitsProcessor` that enforces that specified sequences will never be sampled.
|
||||
|
||||
Args:
|
||||
bad_words_ids (:obj:`List[List[int]]`):
|
||||
List of list of token ids that are not allowed to be generated. In order to get the tokens of the words
|
||||
that should not appear in the generated text, use :obj:`tokenizer(bad_word,
|
||||
add_prefix_space=True).input_ids`.
|
||||
eos_token_id (:obj:`int`):
|
||||
The id of the `end-of-sequence` token.
|
||||
"""
|
||||
|
||||
def __init__(self, bad_words_ids: Iterable[Iterable[int]], eos_token_id: int):
|
||||
|
||||
if not isinstance(bad_words_ids, List) or len(bad_words_ids) == 0:
|
||||
raise ValueError(f"`bad_words_ids` has to be a non-emtpy list, but is {bad_words_ids}.")
|
||||
if any(not isinstance(bad_word_ids, list) for bad_word_ids in bad_words_ids):
|
||||
raise ValueError(f"`bad_words_ids` has to be a list of lists, but is {bad_words_ids}.")
|
||||
if any(
|
||||
any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in bad_word_ids)
|
||||
for bad_word_ids in bad_words_ids
|
||||
):
|
||||
raise ValueError(
|
||||
f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}."
|
||||
)
|
||||
|
||||
self.bad_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids))
|
||||
|
||||
for banned_token_seq in self.bad_words_ids:
|
||||
assert len(banned_token_seq) > 0, "Banned words token sequences {} cannot have an empty list".format(
|
||||
bad_words_ids
|
||||
)
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
banned_tokens = self._calc_banned_bad_words_ids(input_ids)
|
||||
scores = self._set_scores_to_inf_for_banned_tokens(scores, banned_tokens)
|
||||
|
||||
return scores
|
||||
|
||||
def _tokens_match(self, prev_tokens: torch.LongTensor, tokens: List[int]) -> bool:
|
||||
if len(tokens) == 0:
|
||||
# if bad word tokens is just one token always ban it
|
||||
return True
|
||||
elif len(tokens) > len(prev_tokens):
|
||||
# if bad word tokens are longer then prev input_ids they can't be equal
|
||||
return False
|
||||
elif prev_tokens[-len(tokens) :].tolist() == tokens:
|
||||
# if tokens match
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def _calc_banned_bad_words_ids(self, prev_input_ids: Iterable[int]) -> Iterable[int]:
|
||||
banned_tokens = []
|
||||
for prev_input_ids_slice in prev_input_ids:
|
||||
banned_tokens_slice = []
|
||||
for banned_token_seq in self.bad_words_ids:
|
||||
if self._tokens_match(prev_input_ids_slice, banned_token_seq[:-1]) is False:
|
||||
# if tokens do not match continue
|
||||
continue
|
||||
|
||||
banned_tokens_slice.append(banned_token_seq[-1])
|
||||
|
||||
banned_tokens.append(banned_tokens_slice)
|
||||
|
||||
return banned_tokens
|
||||
|
||||
def _set_scores_to_inf_for_banned_tokens(self, scores: torch.Tensor, banned_tokens: List[List[int]]) -> None:
|
||||
"""
|
||||
Modifies the scores in place by setting the banned token positions to `-inf`. Banned token is expected to be a
|
||||
list of list of banned tokens to ban in the format [[batch index, vocabulary position],...
|
||||
|
||||
Args:
|
||||
scores: logits distribution of shape (batch size, vocabulary size)
|
||||
banned_tokens: list of list of tokens to ban of length (batch_size)
|
||||
"""
|
||||
banned_mask_list = []
|
||||
for idx, batch_banned_tokens in enumerate(banned_tokens):
|
||||
for token in batch_banned_tokens:
|
||||
banned_mask_list.append([idx, token])
|
||||
if not banned_mask_list:
|
||||
return scores
|
||||
|
||||
banned_mask = torch.LongTensor(banned_mask_list)
|
||||
indices = torch.ones(len(banned_mask))
|
||||
# A sparse tensor is generated from a list of coordinates: [[0, 1], [0, 2], [2, 0]]. A conversion to dense tensor generates:
|
||||
# [ 0 1 1 ]
|
||||
# [ 0 0 0 ]
|
||||
# [ 1 0 0 ]
|
||||
|
||||
banned_mask = (
|
||||
torch.sparse.LongTensor(banned_mask.t(), indices, scores.size()).to(scores.device).to_dense().bool()
|
||||
)
|
||||
scores = scores.masked_fill(banned_mask, -float("inf"))
|
||||
return scores
|
||||
Reference in New Issue
Block a user