From 5f0801d174a69e38d28eeba47e77545a27a260be Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 21 Jun 2023 11:14:41 +0100 Subject: [PATCH] Generate: add SequenceBiasLogitsProcessor (#24334) --- docs/source/en/internal/generation_utils.md | 3 + src/transformers/__init__.py | 2 + src/transformers/generation/__init__.py | 2 + .../generation/configuration_utils.py | 12 +- src/transformers/generation/logits_process.py | 310 +++++++++++------- src/transformers/generation/utils.py | 6 +- src/transformers/utils/dummy_pt_objects.py | 7 + tests/generation/test_logits_process.py | 22 ++ 8 files changed, 241 insertions(+), 123 deletions(-) diff --git a/docs/source/en/internal/generation_utils.md b/docs/source/en/internal/generation_utils.md index c158be36ba..f5c882bf12 100644 --- a/docs/source/en/internal/generation_utils.md +++ b/docs/source/en/internal/generation_utils.md @@ -141,6 +141,9 @@ generation. [[autodoc]] NoRepeatNGramLogitsProcessor - __call__ +[[autodoc]] SequenceBiasLogitsProcessor + - __call__ + [[autodoc]] NoBadWordsLogitsProcessor - __call__ diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 1254761a2a..3217a9ca04 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -970,6 +970,7 @@ else: "PhrasalConstraint", "PrefixConstrainedLogitsProcessor", "RepetitionPenaltyLogitsProcessor", + "SequenceBiasLogitsProcessor", "StoppingCriteria", "StoppingCriteriaList", "TemperatureLogitsWarper", @@ -4733,6 +4734,7 @@ if TYPE_CHECKING: PhrasalConstraint, PrefixConstrainedLogitsProcessor, RepetitionPenaltyLogitsProcessor, + SequenceBiasLogitsProcessor, StoppingCriteria, StoppingCriteriaList, TemperatureLogitsWarper, diff --git a/src/transformers/generation/__init__.py b/src/transformers/generation/__init__.py index bf87b6e5ff..0a522e9bb7 100644 --- a/src/transformers/generation/__init__.py +++ b/src/transformers/generation/__init__.py @@ -56,6 +56,7 @@ else: "NoRepeatNGramLogitsProcessor", 
"PrefixConstrainedLogitsProcessor", "RepetitionPenaltyLogitsProcessor", + "SequenceBiasLogitsProcessor", "EncoderRepetitionPenaltyLogitsProcessor", "TemperatureLogitsWarper", "TopKLogitsWarper", @@ -182,6 +183,7 @@ if TYPE_CHECKING: NoRepeatNGramLogitsProcessor, PrefixConstrainedLogitsProcessor, RepetitionPenaltyLogitsProcessor, + SequenceBiasLogitsProcessor, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper, diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 26246f3a75..20caed9cc5 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -142,11 +142,8 @@ class GenerationConfig(PushToHubMixin): no_repeat_ngram_size (`int`, *optional*, defaults to 0): If set to int > 0, all ngrams of that size can only occur once. bad_words_ids(`List[List[int]]`, *optional*): - List of token ids that are not allowed to be generated. In order to get the token ids of the words that - should not appear in the generated text, make sure to set `add_prefix_space=True` when initializing the - tokenizer, and use `tokenizer(bad_words, add_special_tokens=False).input_ids`. The `add_prefix_space` - argument is only supported for some slow tokenizers, as fast tokenizers' prefixing behaviours come from - `pre tokenizers`. Read more [here](https://huggingface.co/docs/tokenizers/api/pre-tokenizers). + List of list of token ids that are not allowed to be generated. Check + [`~generation.NoBadWordsLogitsProcessor`] for further documentation and examples. force_words_ids(`List[List[int]]` or `List[List[List[int]]]`, *optional*): List of token ids that must be generated. If given a `List[List[int]]`, this is treated as a simple list of words that must be included, the opposite to `bad_words_ids`. 
If given `List[List[List[int]]]`, this @@ -183,6 +180,10 @@ class GenerationConfig(PushToHubMixin): A list of pairs of integers which indicates a mapping from generation indices to token indices that will be forced before sampling. For example, `[[1, 123]]` means the second generated token will always be a token of index 123. + sequence_bias (`Dict[Tuple[int], float]`, *optional*): + Dictionary that maps a sequence of tokens to its bias term. Positive biases increase the odds of the + sequence being selected, while negative biases do the opposite. Check + [`~generation.SequenceBiasLogitsProcessor`] for further documentation and examples. > Parameters that define the output variables of `generate` @@ -262,6 +263,7 @@ class GenerationConfig(PushToHubMixin): self.suppress_tokens = kwargs.pop("suppress_tokens", None) self.begin_suppress_tokens = kwargs.pop("begin_suppress_tokens", None) self.forced_decoder_ids = kwargs.pop("forced_decoder_ids", None) + self.sequence_bias = kwargs.pop("sequence_bias", None) # Parameters that define the output variables of `generate` self.num_return_sequences = kwargs.pop("num_return_sequences", 1) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index 8595a827ef..1b32e592dc 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -15,7 +15,7 @@ import inspect import math -from typing import Callable, Iterable, List, Optional, Tuple, Union +from typing import Callable, Dict, Iterable, List, Tuple, Union import numpy as np import torch @@ -539,23 +539,208 @@ class EncoderNoRepeatNGramLogitsProcessor(LogitsProcessor): return scores -class NoBadWordsLogitsProcessor(LogitsProcessor): +class SequenceBiasLogitsProcessor(LogitsProcessor): """ - [`LogitsProcessor`] that enforces that specified sequences will never be sampled. + [`LogitsProcessor`] that applies an additive bias on sequences. 
The bias is applied to the last token of a sequence + when the next generated token can complete it. Consequently, to make the most of biasing sequences with more than + one token, consider using beam methods (to gracefully work around partially completed sequences that have a + negative bias) and applying the bias to their prefixes (to ensure the bias is applied earlier). + + + + In order to get the token ids of the sequences that you want to bias, make sure to set `add_prefix_space=True` when + initializing the tokenizer, and use `tokenizer(bad_words, add_special_tokens=False).input_ids`. The + `add_prefix_space` argument is only supported for some slow tokenizers, as fast tokenizers' prefixing behaviours + come from `pre tokenizers`. Read more [here](https://huggingface.co/docs/tokenizers/api/pre-tokenizers). + + + + Args: + sequence_bias (`Dict[Tuple[int], float]`): + Dictionary that maps a sequence of tokens to its bias term. Positive biases increase the odds of the + sequence being selected, while negative biases do the opposite. If a sequence has a length of 1, its bias + will always be applied. Otherwise, the bias will only be applied if the sequence in question is about to be + completed (in the token selection step after this processor is applied). + + Examples: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM + + >>> model = AutoModelForCausalLM.from_pretrained("gpt2") + >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") + >>> inputs = tokenizer(["The full name of Donald is Donald"], return_tensors="pt") + + >>> summary_ids = model.generate(inputs["input_ids"], max_new_tokens=4) + >>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True)[0]) + The full name of Donald is Donald J. Trump Jr + + >>> # Now let's control generation through a bias. Please note that the tokenizer is initialized differently! 
+ >>> tokenizer_with_prefix_space = AutoTokenizer.from_pretrained("gpt2", add_prefix_space=True) + + + >>> def get_tokens_as_tuple(word): + ... return tuple(tokenizer_with_prefix_space([word], add_special_tokens=False).input_ids[0]) + + + >>> # If we add a negative bias without beam search, it may become "stuck" in a prefix without good continuations + >>> sequence_bias = {get_tokens_as_tuple("Trump"): -10.0} + >>> biased_ids = model.generate(inputs["input_ids"], max_new_tokens=4, sequence_bias=sequence_bias) + >>> print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0]) + The full name of Donald is Donald J. Donald, + + >>> biased_ids = model.generate(inputs["input_ids"], max_new_tokens=4, num_beams=4, sequence_bias=sequence_bias) + >>> print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0]) + The full name of Donald is Donald Rumsfeld, + + >>> # We can also add a positive bias to nudge the model towards specific tokens or continuations + >>> sequence_bias = {get_tokens_as_tuple("Donald Duck"): 10.0} + >>> biased_ids = model.generate(inputs["input_ids"], max_new_tokens=4, num_beams=4, sequence_bias=sequence_bias) + >>> print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0]) + The full name of Donald is Donald Duck. + ``` + """ + + def __init__(self, sequence_bias: Dict[Tuple[int], float]): + self.sequence_bias = sequence_bias + self._validate_arguments() + + # Bias variables that will be populated on the first call (for retrocompatibility purposes, the vocabulary size + # is inferred in the first usage, which inhibits initializing here) + self.sequences_length_greater_than_1 = [] + self.length_1_bias = None + self.length_greather_than_1_bias = None + self.prepared_bias_variables = False + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + # 1 - Prepares the bias tensors. This is only needed the first time the logit processor is called. 
+ if not self.prepared_bias_variables: + self._prepare_bias_variables(scores) + + # 2 - prepares an empty bias to add + bias = torch.zeros_like(scores) + + # 3 - include the bias from length = 1 + bias += self.length_1_bias + + # 4 - include the bias from length > 1, after determining which biased sequences may be completed. + # `matching_mask` is a (batch_size, vocab_size) boolean mask that is True for all tokens whose corresponding + # bias should be applied. The bias is applied on the last token of the sequence, if (and only if) the sequence + # may become complete this iteration. + matching_mask = torch.zeros_like(scores, dtype=torch.bool) + for sequence_ids in self.sequences_length_greater_than_1: + if len(sequence_ids) > input_ids.shape[1]: # the sequence is longer than the context, ignore + continue + prefix_length = len(sequence_ids) - 1 + last_token = sequence_ids[-1] + matching_rows = torch.eq( + input_ids[:, -prefix_length:], + torch.tensor(sequence_ids[:-1], dtype=input_ids.dtype, device=input_ids.device), + ).prod(dim=1) + matching_mask[:, last_token] |= matching_rows.bool() + bias += torch.where(matching_mask, self.length_greather_than_1_bias, 0.0) + + # 5 - apply the bias to the scores + scores = scores + bias + return scores + + def _prepare_bias_variables(self, scores: torch.FloatTensor): + vocabulary_size = scores.shape[-1] + sequence_bias = self.sequence_bias + tokens_with_bias = [] + + # Check biased tokens out of bounds + invalid_biases = [] + for sequence_ids in sequence_bias: + for token_id in sequence_ids: + if token_id >= vocabulary_size: + invalid_biases.append(token_id) + if len(invalid_biases) > 0: + raise ValueError( + f"The model vocabulary size is {vocabulary_size}, but the following tokens were being biased: " + f"{invalid_biases}" + ) + + # Precompute the bias tensors to be applied. Sequences of length 1 are kept separately, as they can be applied + # with simpler logic. 
+ self.length_1_bias = torch.zeros((vocabulary_size,), dtype=torch.float).to(scores.device) + self.length_greather_than_1_bias = torch.zeros((vocabulary_size,), dtype=torch.float).to(scores.device) + for sequence_ids, bias in sequence_bias.items(): + if len(sequence_ids) == 1: + self.length_1_bias[sequence_ids[-1]] = bias + else: + self.sequences_length_greater_than_1.append(sequence_ids) + if self.length_greather_than_1_bias[sequence_ids[-1]] != 0.0: + raise ValueError( + "Setting a bias on sequences that share a common token termination is not yet supported. " + "Please open an issue if you see this error message (after checking that it doesn't already " + "exist)." + ) + self.length_greather_than_1_bias[sequence_ids[-1]] = bias + tokens_with_bias.append(sequence_ids[-1]) + + self.prepared_bias_variables = True + + def _validate_arguments(self): + sequence_bias = self.sequence_bias + if not isinstance(sequence_bias, dict) or len(sequence_bias) == 0: + raise ValueError(f"`sequence_bias` has to be a non-empty dictionary, but is {sequence_bias}.") + if any(not isinstance(sequence_ids, tuple) for sequence_ids in sequence_bias.keys()): + raise ValueError(f"`sequence_bias` has to be a dict with tuples as keys, but is {sequence_bias}.") + if any( + any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in sequence_ids) + or len(sequence_ids) == 0 + for sequence_ids in sequence_bias.keys() + ): + raise ValueError( + f"Each key in `sequence_bias` has to be a non-empty tuple of positive integers, but is " + f"{sequence_bias}." + ) + if any(not isinstance(bias, float) for bias in sequence_bias.values()): + raise ValueError(f"`sequence_bias` has to be a dict with floats as values, but is {sequence_bias}.") + + +class NoBadWordsLogitsProcessor(SequenceBiasLogitsProcessor): + """ + [`LogitsProcessor`] that enforces that specified sequences will never be selected. 
+ + + + In order to get the token ids of the words that should not appear in the generated text, make sure to set + `add_prefix_space=True` when initializing the tokenizer, and use `tokenizer(bad_words, + add_special_tokens=False).input_ids`. The `add_prefix_space` argument is only supported for some slow tokenizers, + as fast tokenizers' prefixing behaviours come from `pre tokenizers`. Read more + [here](https://huggingface.co/docs/tokenizers/api/pre-tokenizers). + + Args: bad_words_ids (`List[List[int]]`): - List of list of token ids that are not allowed to be generated. In order to get the token ids of the words - that should not appear in the generated text, make sure to set `add_prefix_space=True` when initializing - the tokenizer, and use `tokenizer(bad_words, add_special_tokens=False).input_ids`. The `add_prefix_space` - argument is only supported for some slow tokenizers, as fast tokenizers' prefixing behaviours come from - `pre tokenizers`. Read more [here](https://huggingface.co/docs/tokenizers/api/pre-tokenizers). + List of list of token ids that are not allowed to be generated. eos_token_id (`Union[int, List[int]]`): The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. 
""" def __init__(self, bad_words_ids: List[List[int]], eos_token_id: Union[int, List[int]]): - if not isinstance(bad_words_ids, List) or len(bad_words_ids) == 0: + self.bad_word_ids = bad_words_ids + self._validate_arguments() + + # Filter EOS token from bad_words_ids + if eos_token_id is None: + eos_token_id = [] + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + bad_words_ids = list( + filter(lambda bad_token_seq: all([bad_token_seq != [i] for i in eos_token_id]), bad_words_ids) + ) + + # Forbidding a sequence is equivalent to setting its bias to -inf + sequence_bias = {tuple(sequence): float("-inf") for sequence in bad_words_ids} + super().__init__(sequence_bias=sequence_bias) + + def _validate_arguments(self): + bad_words_ids = self.bad_word_ids + if not isinstance(bad_words_ids, list) or len(bad_words_ids) == 0: raise ValueError(f"`bad_words_ids` has to be a non-empty list, but is {bad_words_ids}.") if any(not isinstance(bad_word_ids, list) for bad_word_ids in bad_words_ids): raise ValueError(f"`bad_words_ids` has to be a list of lists, but is {bad_words_ids}.") @@ -567,113 +752,6 @@ class NoBadWordsLogitsProcessor(LogitsProcessor): f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}." 
) - if eos_token_id is None: - eos_token_id = [] - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - - bad_words_ids = list( - filter(lambda bad_token_seq: all([bad_token_seq != [i] for i in eos_token_id]), bad_words_ids) - ) - self.bad_words_id_length_1 = [] - self.bad_words_id_length_greater_than_1 = [] - for word in bad_words_ids: - if len(word) == 1: - self.bad_words_id_length_1.append(word[0]) - else: - self.bad_words_id_length_greater_than_1.append(word) - - self.static_bad_words_mask: Optional[torch.LongTensor] = None - - for banned_token_seq in self.bad_words_id_length_greater_than_1: - if len(banned_token_seq) == 0: - raise ValueError(f"Banned words token sequences {bad_words_ids} cannot have an empty list") - - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: - if self.static_bad_words_mask is None and len(self.bad_words_id_length_1) > 0: - self.static_bad_words_mask = self._calc_static_bad_word_mask(scores) - - dynamic_banned_tokens = self._calc_banned_bad_words_ids(input_ids.tolist()) - scores = self._set_scores_to_inf_for_banned_tokens(scores, dynamic_banned_tokens) - - return scores - - def _calc_static_bad_word_mask(self, scores: torch.FloatTensor) -> torch.BoolTensor: - static_bad_words_mask = torch.zeros(scores.shape[1]) - static_bad_words_mask[self.bad_words_id_length_1] = 1 - return static_bad_words_mask.unsqueeze(0).to(scores.device).bool() - - def _tokens_match(self, prev_tokens: List[int], tokens: List[int]) -> bool: - if len(tokens) == 0: - # if bad word tokens is just one token always ban it - return True - elif len(tokens) > len(prev_tokens): - # if bad word tokens are longer then prev input_ids they can't be equal - return False - else: - return prev_tokens[-len(tokens) :] == tokens - - def _calc_banned_bad_words_ids(self, prev_input_ids: List[List[int]]) -> Iterable[int]: - banned_tokens = [] - for prev_input_ids_slice in prev_input_ids: - banned_tokens_slice = [] - for 
banned_token_seq in self.bad_words_id_length_greater_than_1: - if self._tokens_match(prev_input_ids_slice, banned_token_seq[:-1]): - banned_tokens_slice.append(banned_token_seq[-1]) - - banned_tokens.append(banned_tokens_slice) - - return banned_tokens - - def _set_scores_to_inf_for_banned_tokens( - self, scores: torch.Tensor, banned_tokens: List[List[int]] - ) -> torch.Tensor: - """ - Modifies the scores in place by setting the banned token positions to `-inf`. Banned token is expected to be a - list of list of banned tokens to ban in the format [[batch index, vocabulary position],... - - Args: - scores: logits distribution of shape (batch size, vocabulary size) - banned_tokens: list of list of tokens to ban of length (batch_size) - """ - banned_mask_list = [] - for idx, batch_banned_tokens in enumerate(banned_tokens): - for token in batch_banned_tokens: - # Eliminates invalid bad word IDs that are over the vocabulary size. - if token <= scores.shape[1]: - banned_mask_list.append([idx, token]) - else: - logger.error( - f"An invalid bad word ID is defined: {token}. This ID is not contained in the " - "vocabulary, and is therefore ignored." - ) - if not banned_mask_list and self.static_bad_words_mask is None: - return scores - - else: - if banned_mask_list: - indices = torch.ones(len(banned_mask_list)) - banned_mask = torch.LongTensor(banned_mask_list, device=indices.device) - # A sparse tensor is generated from a list of coordinates: [[0, 1], [0, 2], [2, 0]]. 
A conversion to dense tensor generates: - # [ 0 1 1 ] - # [ 0 0 0 ] - # [ 1 0 0 ] - - banned_mask = ( - torch.sparse.LongTensor(banned_mask.t(), indices, scores.size()) - .to(scores.device) - .to_dense() - .bool() - ) - - if self.static_bad_words_mask is not None: - banned_mask = torch.bitwise_or(banned_mask, self.static_bad_words_mask) - else: - banned_mask = self.static_bad_words_mask - - scores = scores.masked_fill(banned_mask, -float("inf")) - return scores - class PrefixConstrainedLogitsProcessor(LogitsProcessor): r""" diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index eed177af9e..e5da7a143b 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -56,6 +56,7 @@ from .logits_process import ( NoRepeatNGramLogitsProcessor, PrefixConstrainedLogitsProcessor, RepetitionPenaltyLogitsProcessor, + SequenceBiasLogitsProcessor, SuppressTokensAtBeginLogitsProcessor, SuppressTokensLogitsProcessor, TemperatureLogitsWarper, @@ -842,8 +843,9 @@ class GenerationMixin: # instantiate processors list processors = LogitsProcessorList() - # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files - # all samplers can be found in `generation_utils_samplers.py` + if generation_config.sequence_bias is not None: + processors.append(SequenceBiasLogitsProcessor(sequence_bias=generation_config.sequence_bias)) + if generation_config.diversity_penalty is not None and generation_config.diversity_penalty > 0.0: processors.append( HammingDiversityLogitsProcessor( diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 11e763e2a0..5971244e66 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -240,6 +240,13 @@ class RepetitionPenaltyLogitsProcessor(metaclass=DummyObject): requires_backends(self, ["torch"]) +class 
SequenceBiasLogitsProcessor(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class StoppingCriteria(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/generation/test_logits_process.py b/tests/generation/test_logits_process.py index 85ebf780f7..e560692d4c 100644 --- a/tests/generation/test_logits_process.py +++ b/tests/generation/test_logits_process.py @@ -46,6 +46,7 @@ if is_torch_available(): NoRepeatNGramLogitsProcessor, PrefixConstrainedLogitsProcessor, RepetitionPenaltyLogitsProcessor, + SequenceBiasLogitsProcessor, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper, @@ -512,6 +513,27 @@ class LogitsProcessorTest(unittest.TestCase): filtered_scores = no_bad_words_dist_proc(input_ids, scores.clone()) self.assertTrue(torch.allclose(scores, filtered_scores, atol=1e-3)) + def test_bias_dist_processor(self): + vocab_size = 5 + batch_size = 2 + + input_ids = torch.tensor([[0, 1, 3, 1], [0, 1, 0, 1]], device=torch_device, dtype=torch.long) + positive_bias = {(1,): 100.0, (4,): 100.0} + negative_bias = {(1, 0): -100.0, (0, 1, 2): -100.0, (1, 3, 1, 3): -100.0} + sequence_bias = {**positive_bias, **negative_bias} + + # scores = 0 to facilitate checks + scores = torch.zeros((batch_size, vocab_size), dtype=torch.float, device=torch_device) + + bias_dist_proc = SequenceBiasLogitsProcessor(sequence_bias=sequence_bias) + filtered_scores = bias_dist_proc(input_ids, scores.clone()) + + # batch 1: positive bias: tokens (1, 4); negative bias: tokens (0, 3); neutral: tokens (2) + # batch 2: positive bias: tokens (1, 4); negative bias: tokens (0, 2); neutral: tokens (3) + self.assertListEqual( + filtered_scores.tolist(), [[-100.0, 100.0, 0.0, -100.0, 100.0], [-100.0, 100.0, -100.0, 0.0, 100.0]] + ) + def test_processor_list(self): batch_size = 4 sequence_length = 10