Terminator strings for generate() (#28932)
* stash commit (will discard all of this) * stash commit * First commit - needs a lot of testing! * Add a test * Fix imports and make the tests actually test something * Tests pass! * Rearrange test * Add comments (but it's still a bit confusing) * Stop storing the tokenizer * Comment fixup * Fix for input_ids with a single sequence * Update tests to test single sequences * make fixup * Fix incorrect use of isin() * Expand tests to catch more cases * Expand tests to catch more cases * make fixup * Fix length calculation and update tests * Handle Ġ as a space replacement too * Update src/transformers/generation/stopping_criteria.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Add optimizations from Joao's suggestion * Remove TODO * Update src/transformers/generation/stopping_criteria.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update tests/generation/test_stopping_criteria.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * make fixup * Rename some variables and remove some debugging clauses for clarity * Add tests for the sub-methods * Clarify one test slightly * Add stop_strings to GenerationConfig * generate() supports stop_string arg, asks for tokenizer if not provided * make fixup * Cleanup code and rename variables for clarity * Update tokenizer error * Update tokenizer passing, handle generation on GPU * Slightly more explanation cleanup * More comment cleanup * Factor out the token cleanup so it's more obvious what we're doing, and we can change it later * Careful with that cleanup! * Cleanup + optimizations to _get_matching_positions * More minor performance tweaks * Implement caching and eliminate some expensive ops (startup time: 200ms -> 9ms) * Remove the pin_memory call * Parallelize across all stop strings! 
* Quick fix for tensor devices * Update embeddings test for the new format * Fix test imports * Manual patching for BERT-like tokenizers * Return a bool vector instead of a single True/False * Better comment * Better comment * Add tests from @zucchini-nlp * Amy's list creation nit * tok_list -> token_list * Push a big expanded docstring (should we put it somewhere else?) * Expand docstrings * Docstring fixups * Rebase * make fixup * Make a properly general method for figuring out token strings * Fix naming throughout the functions * Move cache, refactor, fix tests * Add comment * Remove finished TODO * Remove finished TODO * make fixup * Update src/transformers/generation/stopping_criteria.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update and shorten docstring * Update tests to be shorter/clearer and test specific cases --------- Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -86,6 +86,7 @@ else:
|
|||||||
"StoppingCriteria",
|
"StoppingCriteria",
|
||||||
"StoppingCriteriaList",
|
"StoppingCriteriaList",
|
||||||
"validate_stopping_criteria",
|
"validate_stopping_criteria",
|
||||||
|
"StopStringCriteria",
|
||||||
]
|
]
|
||||||
_import_structure["utils"] = [
|
_import_structure["utils"] = [
|
||||||
"GenerationMixin",
|
"GenerationMixin",
|
||||||
@@ -224,6 +225,7 @@ if TYPE_CHECKING:
|
|||||||
MaxTimeCriteria,
|
MaxTimeCriteria,
|
||||||
StoppingCriteria,
|
StoppingCriteria,
|
||||||
StoppingCriteriaList,
|
StoppingCriteriaList,
|
||||||
|
StopStringCriteria,
|
||||||
validate_stopping_criteria,
|
validate_stopping_criteria,
|
||||||
)
|
)
|
||||||
from .utils import (
|
from .utils import (
|
||||||
|
|||||||
@@ -115,6 +115,8 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
max_time(`float`, *optional*):
|
max_time(`float`, *optional*):
|
||||||
The maximum amount of time you allow the computation to run for in seconds. generation will still finish
|
The maximum amount of time you allow the computation to run for in seconds. generation will still finish
|
||||||
the current pass after allocated time has been passed.
|
the current pass after allocated time has been passed.
|
||||||
|
stop_strings(`str or List[str]`, *optional*):
|
||||||
|
A string or a list of strings that should terminate generation if the model outputs them.
|
||||||
|
|
||||||
> Parameters that control the generation strategy used
|
> Parameters that control the generation strategy used
|
||||||
|
|
||||||
@@ -306,6 +308,7 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
self.min_new_tokens = kwargs.pop("min_new_tokens", None)
|
self.min_new_tokens = kwargs.pop("min_new_tokens", None)
|
||||||
self.early_stopping = kwargs.pop("early_stopping", False)
|
self.early_stopping = kwargs.pop("early_stopping", False)
|
||||||
self.max_time = kwargs.pop("max_time", None)
|
self.max_time = kwargs.pop("max_time", None)
|
||||||
|
self.stop_strings = kwargs.pop("stop_strings", None)
|
||||||
|
|
||||||
# Parameters that control the generation strategy used
|
# Parameters that control the generation strategy used
|
||||||
self.do_sample = kwargs.pop("do_sample", False)
|
self.do_sample = kwargs.pop("do_sample", False)
|
||||||
|
|||||||
@@ -1,15 +1,22 @@
|
|||||||
import time
|
import time
|
||||||
import warnings
|
import warnings
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
|
from collections import OrderedDict
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import List, Optional, Union
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
from ..tokenization_utils_base import PreTrainedTokenizerBase
|
||||||
from ..utils import add_start_docstrings, logging
|
from ..utils import add_start_docstrings, logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
# We maintain a module-level cache of the embedding vectors for the stop string criterion
|
||||||
|
# because they are slow to compute
|
||||||
|
STOP_STRING_EMBEDDING_CACHE = OrderedDict()
|
||||||
|
|
||||||
|
|
||||||
STOPPING_CRITERIA_INPUTS_DOCSTRING = r"""
|
STOPPING_CRITERIA_INPUTS_DOCSTRING = r"""
|
||||||
@@ -129,6 +136,334 @@ class MaxTimeCriteria(StoppingCriteria):
|
|||||||
return torch.full((input_ids.shape[0],), is_done, device=input_ids.device, dtype=torch.bool)
|
return torch.full((input_ids.shape[0],), is_done, device=input_ids.device, dtype=torch.bool)
|
||||||
|
|
||||||
|
|
||||||
|
class StopStringCriteria(StoppingCriteria):
|
||||||
|
"""
|
||||||
|
This class can be used to stop generation whenever specific string sequences are generated. It preprocesses
|
||||||
|
the strings together with the tokenizer vocab to find positions where tokens can validly complete the stop strings.
|
||||||
|
|
||||||
|
Generation is stopped as soon as a token is generated that completes any of the stop strings.
|
||||||
|
We want to catch any instance in which the stop string would be present in the decoded output, which means
|
||||||
|
we must also catch cases with "overhangs" off one or both ends. To make this more concrete, for the stop string
|
||||||
|
"stop", any of the following token sequences would trigger the match:
|
||||||
|
|
||||||
|
- ["st", "op"]
|
||||||
|
- ["stop"]
|
||||||
|
- ["st", "opera"]
|
||||||
|
- ["sto", "pper"]
|
||||||
|
- ["las", "topper"]
|
||||||
|
- ["s", "to", "pped"]
|
||||||
|
|
||||||
|
Note that a match will only be triggered if the stop string is at the end of the generated sequence. In other
|
||||||
|
words, these sequences will not trigger a match:
|
||||||
|
|
||||||
|
- ["stop", "at"]
|
||||||
|
- ["st", "op", "at"]
|
||||||
|
- ["st", "opera", "tion"]
|
||||||
|
|
||||||
|
The reason these are not a match is that the stop string does not overlap with the final token. If you can remove
|
||||||
|
one or more tokens from the end of the sequence without destroying the stop string, then this criterion will not
|
||||||
|
match that stop string. This is by design; because this check is run after each token is generated, we can't miss a
|
||||||
|
valid stop string if one is generated, but we don't want to halt generation just because the stop string exists
|
||||||
|
somewhere in the past input_ids.
|
||||||
|
|
||||||
|
How is the match actually performed, though? We do it in quite a confusing way, because we want the entire match
|
||||||
|
process to be compilable with Torch or XLA, which means we cannot use standard string methods. However, it is possible,
|
||||||
|
with some work, to do string matching with pure tensor operations. We'll begin by describing the algorithm we use
|
||||||
|
with standard string operations, and then at the end we'll explain how this is converted to pure tensor operations.
|
||||||
|
|
||||||
|
The key to the algorithm is an observation: Because the stop string must overlap with the end of the token sequence, we can start at
|
||||||
|
the end of the sequence and work backwards. Specifically, we check that there is an overlap between the start of
|
||||||
|
the final token and the end of the stop_string, or to put it another way, stop_string[-i:] == token[:i] for
|
||||||
|
some i > 0. If you look at the positive examples above, you'll see the last token in all of them fulfills this
|
||||||
|
property:
|
||||||
|
|
||||||
|
- ["st", "op"] (overlap is "op", overlap length == 2)
|
||||||
|
- ["stop"] (overlap is "stop", overlap length == 4)
|
||||||
|
- ["st", "opera"] (overlap is "op", overlap length == 2)
|
||||||
|
- ["sto", "pper"] (overlap is "p", overlap length == 1)
|
||||||
|
- ["las", "topper"] (overlap is "top", overlap length == 3)
|
||||||
|
- ["s", "to", "pped"] (overlap is "p", overlap length == 1)
|
||||||
|
|
||||||
|
It's impossible to construct a matching sequence that does not have this property (feel free to verify this
|
||||||
|
yourself). However, although this overlap between the start of the final token and the end of the stop string is
|
||||||
|
necessary for a match, it is not sufficient. We also need to check that the rest of the token sequence is
|
||||||
|
consistent with the stop string.
|
||||||
|
|
||||||
|
How do we do that? Let's use ["s", "to", "pped"] as an example. We know that the final token, "pped", has an
|
||||||
|
overlap of 1 with the stop string, "stop". We then go back to the previous token, "to". Since we have already
|
||||||
|
matched 1 character from the stop string, the remainder to check is "sto". We check that the next token "to"
|
||||||
|
matches the end of the remainder, which it does. We have now matched 3 characters from the stop string, and the
|
||||||
|
remainder to match is "s". We go back to the previous token again, which is also "s". This is a match, and so
|
||||||
|
we have matched the entire stop string.
|
||||||
|
|
||||||
|
How does it work when the tokens run off the start of the stop string, though? Let's consider the example of
|
||||||
|
["las", "topper"]. The final token, "topper", has an overlap of 3 with the stop string, "stop". Therefore,
|
||||||
|
the remaining stop string to match is "s". We go back to the previous token, "las". Because the remainder to
|
||||||
|
match is just "s", with length 1, we consider only the final 1 character from the token, which is "s". This
|
||||||
|
matches the stop string, and so the entire string is matched.
|
||||||
|
|
||||||
|
How do we compute these matches with tensor operations, though? Simply: we efficiently precompute the necessary
|
||||||
|
information for all tokens! For every token, we compute:
|
||||||
|
- Its overlap with the end of the stop string, if any
|
||||||
|
- The positions inside the stop string where the token matches, including matches that run off the start.
|
||||||
|
- The total length of the token
|
||||||
|
|
||||||
|
For example, for the token "pped", we would compute an end overlap of 1, no internal matching positions,
|
||||||
|
and a length of 4. For the token "to", we would compute no end overlap, a single internal matching position
|
||||||
|
of 1 (counting from the end), and a length of 2. For the token "s", we would compute no end overlap,
|
||||||
|
a single internal matching position of 3 (again counting from the end) and a length of 1.
|
||||||
|
|
||||||
|
As long as we have this information, we can execute the algorithm above without any string comparison
|
||||||
|
operations. We simply perform the following steps:
|
||||||
|
- Check if the final token has an end-overlap with the stop string
|
||||||
|
- Continue backwards, keeping track of how much of the stop string we've matched so far
|
||||||
|
- At each point, check if the next token has the current position as one of its valid positions
|
||||||
|
- Continue until either a match fails, or we completely match the whole stop string
|
||||||
|
|
||||||
|
Again, consider ["s", "to", "pped"] as an example. "pped" has an end overlap of 1, so we can begin a match.
|
||||||
|
We have matched 1 character so far, so we check that the next token "to", has 1 as a valid position (again,
|
||||||
|
counting from the end). It does, so we add the length of "to" to our position tracker. We have now matched
|
||||||
|
3 characters, so we check that the next token "s" has 3 as a valid position. It does, so we add its length
|
||||||
|
to the position tracker. The position tracker is now 4, which is the length of the stop string. We have matched the
|
||||||
|
entire stop string.
|
||||||
|
|
||||||
|
In the second case, ["las", "topper"], "topper" has an end overlap of 3, so we can begin a match. We have
|
||||||
|
matched 3 characters so far, so we check that the next token "las" has 3 as a valid position. It does, because we
|
||||||
|
allow tokens to match positions that run off the start of the stop string. We add its length to the position
|
||||||
|
tracker. The position tracker is now 6, which is greater than the length of the stop string! Don't panic, though -
|
||||||
|
this also counts as a match of the stop string. We have matched the entire stop string.
|
||||||
|
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tokenizer (`PreTrainedTokenizer`):
|
||||||
|
The model's associated tokenizer (necessary to extract vocab and tokenize the termination sequences)
|
||||||
|
stop_strings (`Union[str, List[str]]`):
|
||||||
|
A list of strings that should end generation. If a string is passed, it will be treated like a
|
||||||
|
list with a single element.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
|
>>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
|
||||||
|
>>> model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2")
|
||||||
|
>>> inputs = tokenizer("The biggest states in the USA by land area:", return_tensors="pt")
|
||||||
|
|
||||||
|
>>> gen_out = model.generate(**inputs)
|
||||||
|
>>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])
|
||||||
|
The biggest states in the USA by land area:
|
||||||
|
- Alaska
|
||||||
|
- Texas
|
||||||
|
- California
|
||||||
|
|
||||||
|
>>> # Passing one or more stop strings will halt generation after those strings are emitted
|
||||||
|
>>> # Note that generating with stop strings requires you to pass the tokenizer too
|
||||||
|
>>> gen_out = model.generate(**inputs, stop_strings=["Texas"], tokenizer=tokenizer)
|
||||||
|
>>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])
|
||||||
|
The biggest states in the USA by land area:
|
||||||
|
- Alaska
|
||||||
|
- Texas
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, tokenizer: PreTrainedTokenizerBase, stop_strings: Union[str, List[str]]):
|
||||||
|
if isinstance(stop_strings, str):
|
||||||
|
stop_strings = [stop_strings]
|
||||||
|
self.stop_strings: Tuple[str, ...] = tuple(stop_strings)
|
||||||
|
vocab = tokenizer.get_vocab()
|
||||||
|
token_list, token_indices = tuple(vocab.keys()), tuple(vocab.values())
|
||||||
|
self.embedding_vec, self.max_valid_positions, self.max_valid_end_lens = self.clean_and_embed_tokens_with_cache(
|
||||||
|
token_list, token_indices, self.stop_strings, tokenizer
|
||||||
|
)
|
||||||
|
|
||||||
|
self.maximum_token_len = max([len(stop_string) for stop_string in self.stop_strings])
|
||||||
|
self.num_stop_strings = len(self.stop_strings)
|
||||||
|
self.target_lens = torch.tensor([len(stop_string) for stop_string in stop_strings], dtype=torch.int32)
|
||||||
|
|
||||||
|
def clean_and_embed_tokens_with_cache(self, token_list, token_indices, stop_strings, tokenizer):
|
||||||
|
# We don't use the tokenizer in the cache key, because I don't trust it to have well-behaved equality
|
||||||
|
if (token_list, token_indices, stop_strings) in STOP_STRING_EMBEDDING_CACHE:
|
||||||
|
embedding_vec, max_valid_positions, max_valid_end_lens = STOP_STRING_EMBEDDING_CACHE[
|
||||||
|
(token_list, token_indices, self.stop_strings)
|
||||||
|
]
|
||||||
|
STOP_STRING_EMBEDDING_CACHE.move_to_end((token_list, token_indices, stop_strings))
|
||||||
|
else:
|
||||||
|
clean_token_list, clean_token_indices = self.clean_tokenizer_vocab(tokenizer)
|
||||||
|
embedding_vec, max_valid_positions, max_valid_end_lens = self._stop_string_create_embedding_vec(
|
||||||
|
clean_token_list, clean_token_indices, stop_strings
|
||||||
|
)
|
||||||
|
STOP_STRING_EMBEDDING_CACHE[(token_list, token_indices, stop_strings)] = (
|
||||||
|
embedding_vec,
|
||||||
|
max_valid_positions,
|
||||||
|
max_valid_end_lens,
|
||||||
|
)
|
||||||
|
if len(STOP_STRING_EMBEDDING_CACHE) > 8:
|
||||||
|
STOP_STRING_EMBEDDING_CACHE.popitem(last=False) # Pop from the start, the least recently used item
|
||||||
|
return embedding_vec, max_valid_positions, max_valid_end_lens
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def clean_tokenizer_vocab(tokenizer, static_prefix="abcdef"):
|
||||||
|
"""
|
||||||
|
This method turns a tokenizer vocab into a "clean" vocab where each token represents the actual string
|
||||||
|
it will yield, without any special prefixes like "##" or "Ġ". This is trickier than it looks - the method
|
||||||
|
tokenizer.convert_tokens_to_string() does not always return the correct string because of issues with prefix
|
||||||
|
space addition/removal. To work around this, we add a static prefix to the start of the token, then remove
|
||||||
|
it (and any prefix that may have been introduced with it) after calling convert_tokens_to_string().
|
||||||
|
"""
|
||||||
|
vocab = tokenizer.get_vocab()
|
||||||
|
clean_token_list = []
|
||||||
|
clean_token_indices = []
|
||||||
|
sentence_base = tokenizer(static_prefix, add_special_tokens=False)["input_ids"]
|
||||||
|
tokens_base = [tokenizer._convert_id_to_token(tok) for tok in sentence_base]
|
||||||
|
for token, token_idx in vocab.items():
|
||||||
|
token_string = tokenizer.convert_tokens_to_string(tokens_base + [token])
|
||||||
|
token_string = token_string[token_string.index(static_prefix) + len(static_prefix) :]
|
||||||
|
clean_token_list.append(token_string)
|
||||||
|
clean_token_indices.append(token_idx)
|
||||||
|
return tuple(clean_token_list), tuple(clean_token_indices)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _stop_string_get_matching_positions(
|
||||||
|
token_list, token_indices, stop_strings
|
||||||
|
) -> Tuple[Dict[str, Dict[str, List[int]]], Dict[str, Dict[str, List[int]]]]:
|
||||||
|
"""This function preprocesses stop strings and the tokenizer vocabulary to determine where tokens can
|
||||||
|
validly appear in the stop strings. For each token, it computes a list of positions in the stop string where the
|
||||||
|
token appears, as well as a list of the possible "end overlaps" for that token - that is, the number of characters
|
||||||
|
from the end of the stop string that overlap with the start of the token, which can have more than one value.
|
||||||
|
|
||||||
|
The reason for computing these may seem a bit cryptic - please see the docstring for StopStringCriteria for a full
|
||||||
|
explanation of what these values are for!"""
|
||||||
|
|
||||||
|
token_valid_positions = {}
|
||||||
|
token_end_overlaps = {}
|
||||||
|
for stop_string in stop_strings:
|
||||||
|
reversed_stop_string = stop_string[::-1]
|
||||||
|
token_valid_positions[stop_string] = {}
|
||||||
|
token_end_overlaps[stop_string] = {}
|
||||||
|
for token, tok_idx in zip(token_list, token_indices):
|
||||||
|
reversed_token = token[::-1]
|
||||||
|
matching_positions = []
|
||||||
|
possible_end_lengths = []
|
||||||
|
for i in range(1 - len(token), len(stop_string)):
|
||||||
|
if i < 0:
|
||||||
|
tok = reversed_token[-i:]
|
||||||
|
i = 0
|
||||||
|
else:
|
||||||
|
tok = reversed_token
|
||||||
|
stop = reversed_stop_string[i : i + len(tok)]
|
||||||
|
if tok.startswith(stop):
|
||||||
|
if i == 0:
|
||||||
|
possible_end_lengths.append(min(len(tok), len(stop)))
|
||||||
|
else:
|
||||||
|
matching_positions.append(i)
|
||||||
|
|
||||||
|
if matching_positions:
|
||||||
|
token_valid_positions[stop_string][tok_idx] = matching_positions
|
||||||
|
if possible_end_lengths:
|
||||||
|
token_end_overlaps[stop_string][tok_idx] = possible_end_lengths
|
||||||
|
return token_valid_positions, token_end_overlaps
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _stop_string_create_embedding_vec(token_list, token_indices, stop_strings) -> Dict[str, torch.tensor]:
|
||||||
|
"""This function precomputes everything needed for the run-time checks in StopStringCriteria, and packs
|
||||||
|
them into an embedding tensor that can be accessed with pure tensor operations. For the specifics of the values
|
||||||
|
that are precomputed and what they are used for, please refer to the StopStringCriteria docstring!"""
|
||||||
|
token_valid_positions, token_end_overlaps = StopStringCriteria._stop_string_get_matching_positions(
|
||||||
|
token_list, token_indices, stop_strings
|
||||||
|
)
|
||||||
|
|
||||||
|
max_valid_positions = max(
|
||||||
|
len(val) for positions in token_valid_positions.values() for val in positions.values()
|
||||||
|
)
|
||||||
|
max_valid_end_lens = max(len(val) for positions in token_end_overlaps.values() for val in positions.values())
|
||||||
|
vec_size = len(stop_strings) * (max_valid_positions + max_valid_end_lens) + 1
|
||||||
|
gather_vec = np.full((len(token_list), vec_size), dtype=np.int32, fill_value=-1)
|
||||||
|
|
||||||
|
for i, stop_string in enumerate(stop_strings):
|
||||||
|
positions = token_valid_positions[stop_string]
|
||||||
|
end_lens = token_end_overlaps[stop_string]
|
||||||
|
|
||||||
|
# Since this is lots of very small assignments of lists, we build it with numpy rather
|
||||||
|
# than torch for speed + simplicity, then convert to torch at the end
|
||||||
|
for token_idx, valid_positions in positions.items():
|
||||||
|
gather_vec[
|
||||||
|
token_idx, max_valid_positions * i : max_valid_positions * i + len(valid_positions)
|
||||||
|
] = valid_positions
|
||||||
|
for token_idx, possible_end_lens in end_lens.items():
|
||||||
|
gather_vec[
|
||||||
|
token_idx,
|
||||||
|
max_valid_positions * len(stop_strings) + max_valid_end_lens * i : max_valid_positions
|
||||||
|
* len(stop_strings)
|
||||||
|
+ max_valid_end_lens * i
|
||||||
|
+ len(possible_end_lens),
|
||||||
|
] = possible_end_lens
|
||||||
|
for token, token_idx in zip(token_list, token_indices):
|
||||||
|
gather_vec[token_idx, -1] = len(token)
|
||||||
|
|
||||||
|
gather_vec = torch.tensor(gather_vec, dtype=torch.int32)
|
||||||
|
|
||||||
|
return gather_vec, max_valid_positions, max_valid_end_lens
|
||||||
|
|
||||||
|
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.Tensor:
    # Checks, for every row of input_ids, whether the trailing tokens spell out any of the stop strings.
    # Returns one boolean per row. `scores` and `kwargs` are not used by this criterion.

    # Keep the precomputed lookup tensors on the same device as the incoming ids
    self.embedding_vec = self.embedding_vec.to(input_ids.device)
    self.target_lens = self.target_lens.to(input_ids.device)
    # The maximum length we need to consider is 1 token per character. Note that input_ids can also be
    # *shorter* than the global max, and the code below should be ready for that
    input_ids = input_ids[:, -self.maximum_token_len :]

    # Flip input_ids because we're only matching strings at the end of the generated sequence
    # (after the flip, index 0 along dim 1 is the most recently generated token)
    flipped_ids = torch.flip(input_ids, (1,))

    # Size of the vector of positions a single token can match
    max_valid_positions = self.max_valid_positions

    # The embedding vec contains the valid positions, end_lengths and total lengths for each token
    embedded = F.embedding(flipped_ids, self.embedding_vec)

    # Now we split the embedding vector. valid_positions is the positions in the stop string the token can fit
    # Only tokens *before* the final one (index 1 onwards after the flip) are checked against valid positions
    valid_positions = embedded[:, 1:, : max_valid_positions * self.num_stop_strings].unflatten(
        -1, (self.num_stop_strings, -1)
    )
    # end_lengths is the number of characters from the string, counting from the end, that the token
    # contains. It can have multiple values if the same token can overlap different end lengths
    # Only the final token (index 0 after the flip) is looked up here
    end_lengths = embedded[:, :1, max_valid_positions * self.num_stop_strings : -1].unflatten(
        -1, (self.num_stop_strings, -1)
    )
    # Lengths is the total length of each token. Unlike the others, it always has a single value
    lengths = embedded[:, 1:, None, -1:]  # Insert a dummy dimension for stop_strings even though lengths are const

    # Concatenate lengths onto each possible end_lengths value
    lengths = lengths.expand((-1, -1, end_lengths.shape[-2], end_lengths.shape[-1]))
    lengths_with_ends = torch.cat([end_lengths, lengths], dim=1)

    # cumsum() to get the number of matched characters in the stop string after each token
    cumsum = lengths_with_ends.cumsum(dim=1)  # B x maximum_token_len x num_stop_strings x max_valid_end_lens

    # The calculation above assumes that all tokens are in valid positions. Now we mask the ones that are not.
    # First, tokens match the start of the string if they have a positive value in the end_lengths vector
    # (unused slots in the embedding are filled with -1, so they never count as a match)
    initial_match = end_lengths > 0

    # Tokens continue the string if the cumsum() so far is one of the valid positions for that token
    # Note that we're actually tracking one cumsum() for each possible end_length
    later_match = torch.any(cumsum[:, :-1, :, None] == valid_positions[:, :, :, :, None], axis=-2)

    # The match vector is a boolean vector that indicates which positions have valid tokens
    match = torch.cat([initial_match, later_match], dim=1)

    # Once a single position does not match, all positions following that position are masked
    # (a running count of mismatches stays at 0 only while every token so far has matched)
    mask = (~match).cumsum(dim=1, dtype=torch.int32)
    mask = mask == 0

    # The string is matched if we reached a cumsum equal to or greater than the length of the string
    # before hitting the mask
    string_matches = torch.amax(cumsum * mask, dim=(1, -1)) >= self.target_lens[None, :]

    # We return a per-sample vector that is True if any stop string is matched for that sample
    return torch.any(string_matches, dim=-1)
|
||||||
|
|
||||||
|
|
||||||
class EosTokenCriteria(StoppingCriteria):
|
class EosTokenCriteria(StoppingCriteria):
|
||||||
"""
|
"""
|
||||||
This class can be used to stop generation whenever the "end-of-sequence" token is generated.
|
This class can be used to stop generation whenever the "end-of-sequence" token is generated.
|
||||||
|
|||||||
@@ -80,12 +80,14 @@ from .stopping_criteria import (
|
|||||||
MaxTimeCriteria,
|
MaxTimeCriteria,
|
||||||
StoppingCriteria,
|
StoppingCriteria,
|
||||||
StoppingCriteriaList,
|
StoppingCriteriaList,
|
||||||
|
StopStringCriteria,
|
||||||
validate_stopping_criteria,
|
validate_stopping_criteria,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..modeling_utils import PreTrainedModel
|
from ..modeling_utils import PreTrainedModel
|
||||||
|
from ..tokenization_utils_base import PreTrainedTokenizerBase
|
||||||
from .streamers import BaseStreamer
|
from .streamers import BaseStreamer
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
@@ -885,7 +887,11 @@ class GenerationMixin:
|
|||||||
return processors
|
return processors
|
||||||
|
|
||||||
def _get_stopping_criteria(
|
def _get_stopping_criteria(
|
||||||
self, generation_config: GenerationConfig, stopping_criteria: Optional[StoppingCriteriaList]
|
self,
|
||||||
|
generation_config: GenerationConfig,
|
||||||
|
stopping_criteria: Optional[StoppingCriteriaList],
|
||||||
|
tokenizer: Optional["PreTrainedTokenizerBase"] = None,
|
||||||
|
**kwargs,
|
||||||
) -> StoppingCriteriaList:
|
) -> StoppingCriteriaList:
|
||||||
criteria = StoppingCriteriaList()
|
criteria = StoppingCriteriaList()
|
||||||
if generation_config.max_length is not None:
|
if generation_config.max_length is not None:
|
||||||
@@ -898,6 +904,14 @@ class GenerationMixin:
|
|||||||
)
|
)
|
||||||
if generation_config.max_time is not None:
|
if generation_config.max_time is not None:
|
||||||
criteria.append(MaxTimeCriteria(max_time=generation_config.max_time))
|
criteria.append(MaxTimeCriteria(max_time=generation_config.max_time))
|
||||||
|
if generation_config.stop_strings is not None:
|
||||||
|
if tokenizer is None:
|
||||||
|
raise ValueError(
|
||||||
|
"There are one or more stop strings, either in the arguments to `generate` or in the "
|
||||||
|
"model's generation config, but we could not locate a tokenizer. When generating with "
|
||||||
|
"stop strings, you must pass the model's tokenizer to the `tokenizer` argument of `generate`."
|
||||||
|
)
|
||||||
|
criteria.append(StopStringCriteria(stop_strings=generation_config.stop_strings, tokenizer=tokenizer))
|
||||||
if generation_config.eos_token_id is not None:
|
if generation_config.eos_token_id is not None:
|
||||||
criteria.append(EosTokenCriteria(eos_token_id=generation_config.eos_token_id))
|
criteria.append(EosTokenCriteria(eos_token_id=generation_config.eos_token_id))
|
||||||
criteria = self._merge_criteria_processor_list(criteria, stopping_criteria)
|
criteria = self._merge_criteria_processor_list(criteria, stopping_criteria)
|
||||||
@@ -1380,6 +1394,7 @@ class GenerationMixin:
|
|||||||
"""
|
"""
|
||||||
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
|
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
|
||||||
self._validate_model_class()
|
self._validate_model_class()
|
||||||
|
tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria
|
||||||
generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
|
generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
|
||||||
self._validate_model_kwargs(model_kwargs.copy())
|
self._validate_model_kwargs(model_kwargs.copy())
|
||||||
|
|
||||||
@@ -1389,6 +1404,7 @@ class GenerationMixin:
|
|||||||
synced_gpus = True
|
synced_gpus = True
|
||||||
else:
|
else:
|
||||||
synced_gpus = False
|
synced_gpus = False
|
||||||
|
|
||||||
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
||||||
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
||||||
|
|
||||||
@@ -1531,7 +1547,7 @@ class GenerationMixin:
|
|||||||
|
|
||||||
# 9. prepare stopping criteria
|
# 9. prepare stopping criteria
|
||||||
prepared_stopping_criteria = self._get_stopping_criteria(
|
prepared_stopping_criteria = self._get_stopping_criteria(
|
||||||
generation_config=generation_config, stopping_criteria=stopping_criteria
|
generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs
|
||||||
)
|
)
|
||||||
# 10. go into different generation modes
|
# 10. go into different generation modes
|
||||||
if generation_mode == GenerationMode.ASSISTED_GENERATION:
|
if generation_mode == GenerationMode.ASSISTED_GENERATION:
|
||||||
|
|||||||
@@ -16,7 +16,7 @@
|
|||||||
import time
|
import time
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import is_torch_available
|
from transformers import AutoTokenizer, is_torch_available
|
||||||
from transformers.testing_utils import require_torch, torch_device
|
from transformers.testing_utils import require_torch, torch_device
|
||||||
|
|
||||||
from ..test_modeling_common import ids_tensor
|
from ..test_modeling_common import ids_tensor
|
||||||
@@ -31,6 +31,7 @@ if is_torch_available():
|
|||||||
MaxNewTokensCriteria,
|
MaxNewTokensCriteria,
|
||||||
MaxTimeCriteria,
|
MaxTimeCriteria,
|
||||||
StoppingCriteriaList,
|
StoppingCriteriaList,
|
||||||
|
StopStringCriteria,
|
||||||
validate_stopping_criteria,
|
validate_stopping_criteria,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -124,3 +125,134 @@ class StoppingCriteriaTestCase(unittest.TestCase):
|
|||||||
stopping_criteria = validate_stopping_criteria(StoppingCriteriaList(), 11)
|
stopping_criteria = validate_stopping_criteria(StoppingCriteriaList(), 11)
|
||||||
|
|
||||||
self.assertEqual(len(stopping_criteria), 1)
|
self.assertEqual(len(stopping_criteria), 1)
|
||||||
|
|
||||||
|
def test_stop_string_criteria(self):
|
||||||
|
true_strings = [
|
||||||
|
"<|im_start|><|im_end|>",
|
||||||
|
"<|im_start|><|im_end|<|im_end|>",
|
||||||
|
">><|im_start|>>stop",
|
||||||
|
"stop",
|
||||||
|
"e nd",
|
||||||
|
]
|
||||||
|
false_strings = [
|
||||||
|
"<|im_start|><|im_end|",
|
||||||
|
"<|im_start|><|im_end|<|im_end|",
|
||||||
|
"<|im_end|><|im_start|>",
|
||||||
|
"<|im_end|<>stop<|im_end|",
|
||||||
|
"end",
|
||||||
|
"en d",
|
||||||
|
"eNd",
|
||||||
|
"<|im_end|",
|
||||||
|
"|im_end|>",
|
||||||
|
"s",
|
||||||
|
]
|
||||||
|
stop_strings = ["<|im_end|>", "stop", "e nd"]
|
||||||
|
|
||||||
|
# Use a tokenizer that won't actually have special tokens for these
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
|
||||||
|
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||||
|
tokenizer.padding_side = "left"
|
||||||
|
true_input_ids = tokenizer(true_strings, return_tensors="pt", padding="longest", add_special_tokens=False)
|
||||||
|
false_input_ids = tokenizer(false_strings, return_tensors="pt", padding="longest", add_special_tokens=False)
|
||||||
|
|
||||||
|
scores = None
|
||||||
|
criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=stop_strings)
|
||||||
|
for i in range(len(true_strings)):
|
||||||
|
self.assertTrue(criteria(true_input_ids["input_ids"][i : i + 1], scores))
|
||||||
|
for i in range(len(false_strings)):
|
||||||
|
self.assertFalse(criteria(false_input_ids["input_ids"][i : i + 1], scores))
|
||||||
|
|
||||||
|
# Now try it with a tokenizer where those are actually special tokens
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("cognitivecomputations/dolphin-2.5-mixtral-8x7b")
|
||||||
|
tokenizer.padding_side = "left"
|
||||||
|
true_input_ids = tokenizer(true_strings, return_tensors="pt", padding="longest", add_special_tokens=False)
|
||||||
|
false_input_ids = tokenizer(false_strings, return_tensors="pt", padding="longest", add_special_tokens=False)
|
||||||
|
|
||||||
|
criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=stop_strings)
|
||||||
|
for i in range(len(true_strings)):
|
||||||
|
self.assertTrue(criteria(true_input_ids["input_ids"][i : i + 1], scores))
|
||||||
|
for i in range(len(false_strings)):
|
||||||
|
self.assertFalse(criteria(false_input_ids["input_ids"][i : i + 1], scores))
|
||||||
|
|
||||||
|
def test_stop_string_matching_positions(self):
|
||||||
|
stop_string = "stop"
|
||||||
|
token_list = ["last", "top", "topper", "s", "p"]
|
||||||
|
token_indices = list(range(len(token_list)))
|
||||||
|
all_token_valid_positions, all_token_end_overlaps = StopStringCriteria._stop_string_get_matching_positions(
|
||||||
|
token_list=token_list, token_indices=token_indices, stop_strings=[stop_string]
|
||||||
|
)
|
||||||
|
valid_positions = {
|
||||||
|
token_list[idx]: positions for idx, positions in all_token_valid_positions[stop_string].items()
|
||||||
|
}
|
||||||
|
end_overlaps = {token_list[idx]: overlaps for idx, overlaps in all_token_end_overlaps[stop_string].items()}
|
||||||
|
self.assertEqual(valid_positions, {"s": [3], "last": [2]})
|
||||||
|
self.assertEqual(end_overlaps, {"top": [3], "topper": [3], "p": [1]})
|
||||||
|
|
||||||
|
def test_stop_string_embedding_vecs(self):
|
||||||
|
stop_string = "stop"
|
||||||
|
token_list = ["last", "top", "topper", "s", "p"]
|
||||||
|
token_indices = list(range(len(token_list)))
|
||||||
|
embedding_vec, max_valid_positions, max_valid_end_lens = StopStringCriteria._stop_string_create_embedding_vec(
|
||||||
|
token_list=token_list, token_indices=token_indices, stop_strings=[stop_string]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Positions inside the stop string where the token matches (excluding end overlaps)
|
||||||
|
valid_positions = embedding_vec[:, 0].tolist()
|
||||||
|
self.assertEqual(valid_positions, [2, -1, -1, 3, -1])
|
||||||
|
|
||||||
|
# Overlap lengths between end of stop string and start of token
|
||||||
|
end_overlaps = embedding_vec[:, 1].tolist()
|
||||||
|
self.assertEqual(end_overlaps, [-1, 3, 3, -1, 1])
|
||||||
|
|
||||||
|
# Length of each token
|
||||||
|
token_lengths = embedding_vec[:, 2].tolist()
|
||||||
|
self.assertEqual(token_lengths, [len(token) for token in token_list])
|
||||||
|
|
||||||
|
def test_criterias_per_row(self):
|
||||||
|
text = "They completed the challenging puzzle, revealing the hidden image at the end"
|
||||||
|
stop_strings = ["end"]
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
|
||||||
|
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||||
|
inputs = tokenizer(text, return_tensors="pt", add_special_tokens=False)
|
||||||
|
|
||||||
|
scores = None
|
||||||
|
criteria = StoppingCriteriaList(
|
||||||
|
[
|
||||||
|
MaxLengthCriteria(max_length=20),
|
||||||
|
StopStringCriteria(tokenizer=tokenizer, stop_strings=stop_strings),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# trigger stopping when at least one criterion is satisfied, one value per batch
|
||||||
|
self.assertTrue(criteria(inputs["input_ids"], scores))
|
||||||
|
|
||||||
|
# return False when neither is satisfied
|
||||||
|
self.assertFalse(criteria(inputs["input_ids"][:, :-1], scores))
|
||||||
|
|
||||||
|
def test_criterias_per_row_batched(self):
|
||||||
|
text = [
|
||||||
|
"They completed the challenging puzzle, revealing the hidden image at the end",
|
||||||
|
"Today a dragon flew over France",
|
||||||
|
"The aroma of freshly baked pizza filled the kitchen",
|
||||||
|
]
|
||||||
|
stop_strings = ["end"]
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
|
||||||
|
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||||
|
tokenizer.padding_side = "left"
|
||||||
|
inputs = tokenizer(text, return_tensors="pt", padding="longest", add_special_tokens=False)
|
||||||
|
|
||||||
|
scores = None
|
||||||
|
criteria = StoppingCriteriaList(
|
||||||
|
[
|
||||||
|
MaxLengthCriteria(max_length=20),
|
||||||
|
StopStringCriteria(tokenizer=tokenizer, stop_strings=stop_strings),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# trigger stopping when at least one criterion is satisfied
|
||||||
|
self.assertListEqual(criteria(inputs["input_ids"], scores).tolist(), [True, False, False])
|
||||||
|
|
||||||
|
# False when neither is satisfied
|
||||||
|
self.assertListEqual(criteria(inputs["input_ids"][:, :-1], scores).tolist(), [False, False, False])
|
||||||
|
|||||||
@@ -2330,6 +2330,43 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
|||||||
|
|
||||||
self.assertListEqual(outputs, ["Wie alt sind Sie?"])
|
self.assertListEqual(outputs, ["Wie alt sind Sie?"])
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_per_row_stopping_criteria(self):
|
||||||
|
text = [
|
||||||
|
"They completed the challenging puzzle, revealing the hidden",
|
||||||
|
"Today a dragon flew over France",
|
||||||
|
"The aroma of freshly baked pizza filled the kitchen",
|
||||||
|
]
|
||||||
|
stop_strings = ["secrets"]
|
||||||
|
|
||||||
|
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2").to(torch_device)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
|
||||||
|
tokenizer.padding_side = "left"
|
||||||
|
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||||
|
input_ids = tokenizer(text, return_tensors="pt", padding="longest", add_special_tokens=False).input_ids.to(
|
||||||
|
torch_device
|
||||||
|
)
|
||||||
|
|
||||||
|
# normal generation with one stopping criterion
|
||||||
|
out = model.generate(input_ids, max_length=15)
|
||||||
|
out_text = tokenizer.batch_decode(out)
|
||||||
|
expected_out = [
|
||||||
|
"They completed the challenging puzzle, revealing the hidden secrets of the world.\n",
|
||||||
|
"<|endoftext|><|endoftext|><|endoftext|>Today a dragon flew over France and the French government was forced",
|
||||||
|
"The aroma of freshly baked pizza filled the kitchen with a sense of freshness",
|
||||||
|
]
|
||||||
|
self.assertListEqual(out_text, expected_out)
|
||||||
|
|
||||||
|
# generation should stop at "secrets" for the first sequence in the batch only, filling the rest of that row with eos tokens
|
||||||
|
out = model.generate(input_ids, max_length=15, stop_strings=stop_strings, tokenizer=tokenizer)
|
||||||
|
out_text = tokenizer.batch_decode(out)
|
||||||
|
expected_out = [
|
||||||
|
"They completed the challenging puzzle, revealing the hidden secrets<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>",
|
||||||
|
"<|endoftext|><|endoftext|><|endoftext|>Today a dragon flew over France and the French government was forced",
|
||||||
|
"The aroma of freshly baked pizza filled the kitchen with a sense of freshness",
|
||||||
|
]
|
||||||
|
self.assertListEqual(out_text, expected_out)
|
||||||
|
|
||||||
def test_constrained_beam_search_mixin_type_checks(self):
|
def test_constrained_beam_search_mixin_type_checks(self):
|
||||||
# PT-only test: TF doesn't have constrained beam search
|
# PT-only test: TF doesn't have constrained beam search
|
||||||
tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/t5-tiny-random")
|
tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/t5-tiny-random")
|
||||||
|
|||||||
Reference in New Issue
Block a user