Move eos_token_id to stopping criteria (#29459)
* add eos stopping criteria * minor fix * Update tests/generation/test_stopping_criteria.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * check eos is not None and fix tests * make style and fixup * Update src/transformers/generation/stopping_criteria.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update tests/generation/test_utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update tests/generation/test_utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation/__init__.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/generation/stopping_criteria.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/generation/stopping_criteria.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/generation/stopping_criteria.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * camel case everywhere * call stopping criteria list for candidate ids * make style and fixup * Empty commit * Empty commit to pass flaky test * set max length in PromptLookupCandidateGenerator * Update src/transformers/generation/utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * lets fix this typo in docs * Update src/transformers/generation/utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/generation/utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * update PR * empty commit --------- Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
31c575bcf1
commit
0efcf32351
@@ -82,6 +82,7 @@ else:
|
||||
"MaxNewTokensCriteria",
|
||||
"MaxLengthCriteria",
|
||||
"MaxTimeCriteria",
|
||||
"EosTokenCriteria",
|
||||
"StoppingCriteria",
|
||||
"StoppingCriteriaList",
|
||||
"validate_stopping_criteria",
|
||||
@@ -216,6 +217,7 @@ if TYPE_CHECKING:
|
||||
WhisperTimeStampLogitsProcessor,
|
||||
)
|
||||
from .stopping_criteria import (
|
||||
EosTokenCriteria,
|
||||
MaxLengthCriteria,
|
||||
MaxNewTokensCriteria,
|
||||
MaxTimeCriteria,
|
||||
|
||||
@@ -238,15 +238,20 @@ class PromptLookupCandidateGenerator(CandidateGenerator):
|
||||
The maximum ngram size to be considered for matching in the prompt
|
||||
num_output_tokens (`int`):
|
||||
The number of tokens to be output as candidate tokens.
|
||||
max_length (`int`):
|
||||
The number of total maximum tokens that can be generated. For decoder-only models that includes the prompt length.
|
||||
Defaults to 20, which is the max length used as default in generation config.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_output_tokens: int = 10,
|
||||
max_matching_ngram_size: int = None,
|
||||
max_length: int = 20,
|
||||
):
|
||||
self.num_output_tokens = num_output_tokens
|
||||
self.max_matching_ngram_size = max_matching_ngram_size if max_matching_ngram_size else 2
|
||||
self.max_length = max_length
|
||||
|
||||
if self.max_matching_ngram_size <= 0 or self.num_output_tokens <= 0:
|
||||
raise ValueError("Invalid max_matching_ngram_size or num_output_tokens")
|
||||
@@ -264,6 +269,10 @@ class PromptLookupCandidateGenerator(CandidateGenerator):
|
||||
"""
|
||||
input_length = input_ids.size(1)
|
||||
|
||||
# Don't generate more than `max_length - 1` candidates since the target model generates one extra token.
|
||||
if self.max_length == input_length + 1:
|
||||
return input_ids, None
|
||||
|
||||
chosen_ids = None
|
||||
match_found = False
|
||||
for ngram_size in range(min(self.max_matching_ngram_size, input_length - 1), 0, -1):
|
||||
@@ -283,7 +292,7 @@ class PromptLookupCandidateGenerator(CandidateGenerator):
|
||||
for idx in match_indices:
|
||||
start_idx = idx + ngram_size
|
||||
end_idx = start_idx + self.num_output_tokens
|
||||
end_idx = min(end_idx, input_length)
|
||||
end_idx = min(end_idx, input_length, self.max_length)
|
||||
|
||||
if start_idx < end_idx:
|
||||
chosen_ids = input_ids[0, start_idx:end_idx]
|
||||
|
||||
@@ -2,7 +2,7 @@ import time
|
||||
import warnings
|
||||
from abc import ABC
|
||||
from copy import deepcopy
|
||||
from typing import Optional
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -129,6 +129,27 @@ class MaxTimeCriteria(StoppingCriteria):
|
||||
return torch.full((input_ids.shape[0],), is_done, device=input_ids.device, dtype=torch.bool)
|
||||
|
||||
|
||||
class EosTokenCriteria(StoppingCriteria):
|
||||
"""
|
||||
This class can be used to stop generation whenever the "end-of-sequence" token is generated.
|
||||
By default, it uses the `model.generation_config.eos_token_id`.
|
||||
|
||||
Args:
|
||||
eos_token_id (`Union[int, List[int]]`):
|
||||
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
|
||||
"""
|
||||
|
||||
def __init__(self, eos_token_id: Union[int, List[int]]):
|
||||
if isinstance(eos_token_id, int):
|
||||
eos_token_id = [eos_token_id]
|
||||
self.eos_token_id = torch.tensor(eos_token_id)
|
||||
|
||||
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
|
||||
is_done = torch.isin(input_ids[:, -1], self.eos_token_id.to(input_ids.device))
|
||||
return is_done
|
||||
|
||||
|
||||
class StoppingCriteriaList(list):
|
||||
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
|
||||
|
||||
@@ -75,6 +75,7 @@ from .logits_process import (
|
||||
UnbatchedClassifierFreeGuidanceLogitsProcessor,
|
||||
)
|
||||
from .stopping_criteria import (
|
||||
EosTokenCriteria,
|
||||
MaxLengthCriteria,
|
||||
MaxTimeCriteria,
|
||||
StoppingCriteria,
|
||||
@@ -690,6 +691,7 @@ class GenerationMixin:
|
||||
candidate_generator = PromptLookupCandidateGenerator(
|
||||
num_output_tokens=generation_config.prompt_lookup_num_tokens,
|
||||
max_matching_ngram_size=generation_config.max_matching_ngram_size,
|
||||
max_length=generation_config.max_length,
|
||||
)
|
||||
else:
|
||||
candidate_generator = AssistedCandidateGenerator(
|
||||
@@ -892,6 +894,8 @@ class GenerationMixin:
|
||||
)
|
||||
if generation_config.max_time is not None:
|
||||
criteria.append(MaxTimeCriteria(max_time=generation_config.max_time))
|
||||
if generation_config.eos_token_id is not None:
|
||||
criteria.append(EosTokenCriteria(eos_token_id=generation_config.eos_token_id))
|
||||
criteria = self._merge_criteria_processor_list(criteria, stopping_criteria)
|
||||
return criteria
|
||||
|
||||
@@ -1306,7 +1310,7 @@ class GenerationMixin:
|
||||
|
||||
Return:
|
||||
[`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
|
||||
or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`.
|
||||
or when `config.return_dict_in_generate=True`) or a `torch.LongTensor`.
|
||||
|
||||
If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
|
||||
[`~utils.ModelOutput`] types are:
|
||||
@@ -1515,7 +1519,6 @@ class GenerationMixin:
|
||||
logits_warper=self._get_logits_warper(generation_config) if generation_config.do_sample else None,
|
||||
stopping_criteria=prepared_stopping_criteria,
|
||||
pad_token_id=generation_config.pad_token_id,
|
||||
eos_token_id=generation_config.eos_token_id,
|
||||
output_scores=generation_config.output_scores,
|
||||
output_logits=generation_config.output_logits,
|
||||
return_dict_in_generate=generation_config.return_dict_in_generate,
|
||||
@@ -1530,7 +1533,6 @@ class GenerationMixin:
|
||||
logits_processor=prepared_logits_processor,
|
||||
stopping_criteria=prepared_stopping_criteria,
|
||||
pad_token_id=generation_config.pad_token_id,
|
||||
eos_token_id=generation_config.eos_token_id,
|
||||
output_scores=generation_config.output_scores,
|
||||
output_logits=generation_config.output_logits,
|
||||
return_dict_in_generate=generation_config.return_dict_in_generate,
|
||||
@@ -1550,7 +1552,6 @@ class GenerationMixin:
|
||||
logits_processor=prepared_logits_processor,
|
||||
stopping_criteria=prepared_stopping_criteria,
|
||||
pad_token_id=generation_config.pad_token_id,
|
||||
eos_token_id=generation_config.eos_token_id,
|
||||
output_scores=generation_config.output_scores,
|
||||
output_logits=generation_config.output_logits,
|
||||
return_dict_in_generate=generation_config.return_dict_in_generate,
|
||||
@@ -1579,7 +1580,6 @@ class GenerationMixin:
|
||||
logits_warper=logits_warper,
|
||||
stopping_criteria=prepared_stopping_criteria,
|
||||
pad_token_id=generation_config.pad_token_id,
|
||||
eos_token_id=generation_config.eos_token_id,
|
||||
output_scores=generation_config.output_scores,
|
||||
output_logits=generation_config.output_logits,
|
||||
return_dict_in_generate=generation_config.return_dict_in_generate,
|
||||
@@ -1613,7 +1613,6 @@ class GenerationMixin:
|
||||
logits_processor=prepared_logits_processor,
|
||||
stopping_criteria=prepared_stopping_criteria,
|
||||
pad_token_id=generation_config.pad_token_id,
|
||||
eos_token_id=generation_config.eos_token_id,
|
||||
output_scores=generation_config.output_scores,
|
||||
output_logits=generation_config.output_logits,
|
||||
return_dict_in_generate=generation_config.return_dict_in_generate,
|
||||
@@ -1653,7 +1652,6 @@ class GenerationMixin:
|
||||
logits_warper=logits_warper,
|
||||
stopping_criteria=prepared_stopping_criteria,
|
||||
pad_token_id=generation_config.pad_token_id,
|
||||
eos_token_id=generation_config.eos_token_id,
|
||||
output_scores=generation_config.output_scores,
|
||||
output_logits=generation_config.output_logits,
|
||||
return_dict_in_generate=generation_config.return_dict_in_generate,
|
||||
@@ -1687,7 +1685,6 @@ class GenerationMixin:
|
||||
logits_processor=prepared_logits_processor,
|
||||
stopping_criteria=prepared_stopping_criteria,
|
||||
pad_token_id=generation_config.pad_token_id,
|
||||
eos_token_id=generation_config.eos_token_id,
|
||||
output_scores=generation_config.output_scores,
|
||||
output_logits=generation_config.output_logits,
|
||||
return_dict_in_generate=generation_config.return_dict_in_generate,
|
||||
@@ -1761,7 +1758,6 @@ class GenerationMixin:
|
||||
logits_processor=prepared_logits_processor,
|
||||
stopping_criteria=prepared_stopping_criteria,
|
||||
pad_token_id=generation_config.pad_token_id,
|
||||
eos_token_id=generation_config.eos_token_id,
|
||||
output_scores=generation_config.output_scores,
|
||||
output_logits=generation_config.output_logits,
|
||||
return_dict_in_generate=generation_config.return_dict_in_generate,
|
||||
@@ -1916,11 +1912,28 @@ class GenerationMixin:
|
||||
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
|
||||
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
||||
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
||||
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
||||
sequential = sequential if sequential is not None else self.generation_config.low_memory
|
||||
if eos_token_id is not None:
|
||||
logger.warning_once(
|
||||
"`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
|
||||
" `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead."
|
||||
" Otherwise make sure to set `model.generation_config.eos_token_id`",
|
||||
FutureWarning,
|
||||
)
|
||||
stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
|
||||
else:
|
||||
# TODO remove when the method is totally private
|
||||
# need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
|
||||
eos_token_id = [
|
||||
criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
|
||||
]
|
||||
eos_token_id = eos_token_id[0] if eos_token_id else None
|
||||
if eos_token_id is None and self.generation_config.eos_token_id is not None:
|
||||
eos_token_id = self.generation_config.eos_token_id
|
||||
stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
|
||||
|
||||
if isinstance(eos_token_id, int):
|
||||
eos_token_id = [eos_token_id]
|
||||
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
|
||||
sequential = sequential if sequential is not None else self.generation_config.low_memory
|
||||
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
||||
output_logits = output_logits if output_logits is not None else self.generation_config.output_logits
|
||||
output_attentions = (
|
||||
@@ -2186,12 +2199,6 @@ class GenerationMixin:
|
||||
is_encoder_decoder=self.config.is_encoder_decoder,
|
||||
)
|
||||
|
||||
# if eos_token was found in one sentence, set sentence to finished
|
||||
if eos_token_id_tensor is not None:
|
||||
unfinished_sequences = unfinished_sequences.mul(
|
||||
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
|
||||
)
|
||||
|
||||
# stop when each sentence is finished
|
||||
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
|
||||
this_peer_finished = unfinished_sequences.max() == 0
|
||||
@@ -2365,10 +2372,27 @@ class GenerationMixin:
|
||||
)
|
||||
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
|
||||
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
||||
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
||||
if eos_token_id is not None:
|
||||
logger.warning_once(
|
||||
"`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
|
||||
" `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead."
|
||||
" Otherwise make sure to set `model.generation_config.eos_token_id`",
|
||||
FutureWarning,
|
||||
)
|
||||
stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
|
||||
else:
|
||||
# TODO remove when the method is totally private
|
||||
# need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
|
||||
eos_token_id = [
|
||||
criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
|
||||
]
|
||||
eos_token_id = eos_token_id[0] if eos_token_id else None
|
||||
if eos_token_id is None and self.generation_config.eos_token_id is not None:
|
||||
eos_token_id = self.generation_config.eos_token_id
|
||||
stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
|
||||
|
||||
if isinstance(eos_token_id, int):
|
||||
eos_token_id = [eos_token_id]
|
||||
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
|
||||
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
||||
output_attentions = (
|
||||
output_attentions if output_attentions is not None else self.generation_config.output_attentions
|
||||
@@ -2463,12 +2487,6 @@ class GenerationMixin:
|
||||
is_encoder_decoder=self.config.is_encoder_decoder,
|
||||
)
|
||||
|
||||
# if eos_token was found in one sentence, set sentence to finished
|
||||
if eos_token_id_tensor is not None:
|
||||
unfinished_sequences = unfinished_sequences.mul(
|
||||
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
|
||||
)
|
||||
|
||||
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
|
||||
this_peer_finished = unfinished_sequences.max() == 0
|
||||
|
||||
@@ -2650,10 +2668,27 @@ class GenerationMixin:
|
||||
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
|
||||
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
|
||||
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
||||
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
||||
if eos_token_id is not None:
|
||||
logger.warning_once(
|
||||
"`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
|
||||
" `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead."
|
||||
" Otherwise make sure to set `model.generation_config.eos_token_id`",
|
||||
FutureWarning,
|
||||
)
|
||||
stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
|
||||
else:
|
||||
# TODO remove when the method is totally private
|
||||
# need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
|
||||
eos_token_id = [
|
||||
criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
|
||||
]
|
||||
eos_token_id = eos_token_id[0] if eos_token_id else None
|
||||
if eos_token_id is None and self.generation_config.eos_token_id is not None:
|
||||
eos_token_id = self.generation_config.eos_token_id
|
||||
stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
|
||||
|
||||
if isinstance(eos_token_id, int):
|
||||
eos_token_id = [eos_token_id]
|
||||
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
|
||||
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
||||
output_logits = output_logits if output_logits is not None else self.generation_config.output_logits
|
||||
output_attentions = (
|
||||
@@ -2751,12 +2786,6 @@ class GenerationMixin:
|
||||
is_encoder_decoder=self.config.is_encoder_decoder,
|
||||
)
|
||||
|
||||
# if eos_token was found in one sentence, set sentence to finished
|
||||
if eos_token_id_tensor is not None:
|
||||
unfinished_sequences = unfinished_sequences.mul(
|
||||
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
|
||||
)
|
||||
|
||||
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
|
||||
this_peer_finished = unfinished_sequences.max() == 0
|
||||
|
||||
@@ -2966,7 +2995,25 @@ class GenerationMixin:
|
||||
if len(stopping_criteria) == 0:
|
||||
warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
|
||||
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
||||
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
||||
if eos_token_id is not None:
|
||||
logger.warning_once(
|
||||
"`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
|
||||
" `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead."
|
||||
" Otherwise make sure to set `model.generation_config.eos_token_id`",
|
||||
FutureWarning,
|
||||
)
|
||||
stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
|
||||
else:
|
||||
# TODO remove when the method is totally private and beam scorer refactored
|
||||
# need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
|
||||
eos_token_id = [
|
||||
criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
|
||||
]
|
||||
eos_token_id = eos_token_id[0] if eos_token_id else None
|
||||
if eos_token_id is None and self.generation_config.eos_token_id is not None:
|
||||
eos_token_id = self.generation_config.eos_token_id
|
||||
stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
|
||||
|
||||
if isinstance(eos_token_id, int):
|
||||
eos_token_id = [eos_token_id]
|
||||
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
||||
@@ -3351,7 +3398,25 @@ class GenerationMixin:
|
||||
)
|
||||
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
|
||||
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
||||
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
||||
if eos_token_id is not None:
|
||||
logger.warning_once(
|
||||
"`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
|
||||
" `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead."
|
||||
" Otherwise make sure to set `model.generation_config.eos_token_id`",
|
||||
FutureWarning,
|
||||
)
|
||||
stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
|
||||
else:
|
||||
# TODO remove when the method is totally private and beam scorer refactored
|
||||
# need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
|
||||
eos_token_id = [
|
||||
criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
|
||||
]
|
||||
eos_token_id = eos_token_id[0] if eos_token_id else None
|
||||
if eos_token_id is None and self.generation_config.eos_token_id is not None:
|
||||
eos_token_id = self.generation_config.eos_token_id
|
||||
stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
|
||||
|
||||
if isinstance(eos_token_id, int):
|
||||
eos_token_id = [eos_token_id]
|
||||
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
||||
@@ -3688,7 +3753,25 @@ class GenerationMixin:
|
||||
)
|
||||
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
|
||||
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
||||
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
||||
if eos_token_id is not None:
|
||||
logger.warning_once(
|
||||
"`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
|
||||
" `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead."
|
||||
" Otherwise make sure to set `model.generation_config.eos_token_id`",
|
||||
FutureWarning,
|
||||
)
|
||||
stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
|
||||
else:
|
||||
# TODO remove when the method is totally private and beam scorer refactored
|
||||
# need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
|
||||
eos_token_id = [
|
||||
criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
|
||||
]
|
||||
eos_token_id = eos_token_id[0] if eos_token_id else None
|
||||
if eos_token_id is None and self.generation_config.eos_token_id is not None:
|
||||
eos_token_id = self.generation_config.eos_token_id
|
||||
stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
|
||||
|
||||
if isinstance(eos_token_id, int):
|
||||
eos_token_id = [eos_token_id]
|
||||
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
||||
@@ -4089,7 +4172,25 @@ class GenerationMixin:
|
||||
if len(stopping_criteria) == 0:
|
||||
warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
|
||||
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
||||
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
||||
if eos_token_id is not None:
|
||||
logger.warning_once(
|
||||
"`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
|
||||
" `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead."
|
||||
" Otherwise make sure to set `model.generation_config.eos_token_id`",
|
||||
FutureWarning,
|
||||
)
|
||||
stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
|
||||
else:
|
||||
# TODO remove when the method is totally private and beam scorer refactored
|
||||
# need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
|
||||
eos_token_id = [
|
||||
criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
|
||||
]
|
||||
eos_token_id = eos_token_id[0] if eos_token_id else None
|
||||
if eos_token_id is None and self.generation_config.eos_token_id is not None:
|
||||
eos_token_id = self.generation_config.eos_token_id
|
||||
stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
|
||||
|
||||
if isinstance(eos_token_id, int):
|
||||
eos_token_id = [eos_token_id]
|
||||
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
||||
@@ -4421,12 +4522,27 @@ class GenerationMixin:
|
||||
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
|
||||
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
||||
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
||||
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
||||
if eos_token_id is not None and pad_token_id is None:
|
||||
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
|
||||
if eos_token_id is not None:
|
||||
logger.warning_once(
|
||||
"`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
|
||||
" `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead."
|
||||
" Otherwise make sure to set `model.generation_config.eos_token_id`",
|
||||
FutureWarning,
|
||||
)
|
||||
stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
|
||||
else:
|
||||
# TODO remove when the method is totally private and beam scorer refactored
|
||||
# need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
|
||||
eos_token_id = [
|
||||
criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
|
||||
]
|
||||
eos_token_id = eos_token_id[0] if eos_token_id else None
|
||||
if eos_token_id is None and self.generation_config.eos_token_id is not None:
|
||||
eos_token_id = self.generation_config.eos_token_id
|
||||
stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
|
||||
|
||||
if isinstance(eos_token_id, int):
|
||||
eos_token_id = [eos_token_id]
|
||||
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
|
||||
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
||||
output_logits = output_logits if output_logits is not None else self.generation_config.output_logits
|
||||
output_attentions = (
|
||||
@@ -4462,9 +4578,6 @@ class GenerationMixin:
|
||||
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
|
||||
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
|
||||
|
||||
# other auxiliary variables
|
||||
max_len = stopping_criteria[0].max_length
|
||||
|
||||
this_peer_finished = False
|
||||
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
|
||||
cur_len = input_ids.shape[-1]
|
||||
@@ -4476,13 +4589,7 @@ class GenerationMixin:
|
||||
candidate_logits = candidate_logits.to(self.device)
|
||||
|
||||
candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
|
||||
last_assistant_token_is_eos = (
|
||||
~candidate_input_ids[:, -1]
|
||||
.tile(eos_token_id_tensor.shape[0], 1)
|
||||
.ne(eos_token_id_tensor.unsqueeze(1))
|
||||
.prod(dim=0)
|
||||
.bool()
|
||||
)
|
||||
is_done_candidate = stopping_criteria(candidate_input_ids, None)
|
||||
|
||||
# 2. Use the original model to obtain the next token logits given the candidate sequence. We obtain
|
||||
# `candidate_length + 1` relevant logits from this process: in the event that all candidates are correct,
|
||||
@@ -4525,15 +4632,13 @@ class GenerationMixin:
|
||||
# 3. Select the accepted tokens. There are two possible cases:
|
||||
# Case 1: `do_sample=True` and we have logits for the candidates (originally from speculative decoding)
|
||||
# 👉 Apply algorithm 1 from the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf).
|
||||
max_matches = max_len - cur_len - 1
|
||||
if do_sample and candidate_logits is not None:
|
||||
valid_tokens, n_matches = _speculative_sampling(
|
||||
candidate_input_ids,
|
||||
candidate_logits,
|
||||
candidate_length,
|
||||
new_logits,
|
||||
last_assistant_token_is_eos,
|
||||
max_matches,
|
||||
is_done_candidate,
|
||||
)
|
||||
|
||||
# Case 2: all other cases (originally from assisted generation) 👉 Compare the tokens selected from the
|
||||
@@ -4550,9 +4655,8 @@ class GenerationMixin:
|
||||
n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum()
|
||||
|
||||
# Ensure we don't generate beyond max_len or an EOS token
|
||||
if last_assistant_token_is_eos and n_matches == candidate_length:
|
||||
if is_done_candidate and n_matches == candidate_length:
|
||||
n_matches -= 1
|
||||
n_matches = min(n_matches, max_matches)
|
||||
valid_tokens = selected_tokens[:, : n_matches + 1]
|
||||
|
||||
# 4. Update variables according to the number of matching assistant tokens. Remember: the token generated
|
||||
@@ -4625,15 +4729,6 @@ class GenerationMixin:
|
||||
is_encoder_decoder=self.config.is_encoder_decoder,
|
||||
)
|
||||
|
||||
# if eos_token was found in one sentence, set sentence to finished
|
||||
if eos_token_id_tensor is not None:
|
||||
unfinished_sequences = unfinished_sequences.mul(
|
||||
input_ids[:, -1]
|
||||
.tile(eos_token_id_tensor.shape[0], 1)
|
||||
.ne(eos_token_id_tensor.unsqueeze(1))
|
||||
.prod(dim=0)
|
||||
)
|
||||
|
||||
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
|
||||
this_peer_finished = unfinished_sequences.max() == 0
|
||||
|
||||
@@ -4678,8 +4773,7 @@ def _speculative_sampling(
|
||||
candidate_logits,
|
||||
candidate_length,
|
||||
new_logits,
|
||||
last_assistant_token_is_eos,
|
||||
max_matches,
|
||||
is_done_candidate,
|
||||
):
|
||||
"""
|
||||
Applies sampling as in the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf, algorithm 1). Returns
|
||||
@@ -4704,16 +4798,14 @@ def _speculative_sampling(
|
||||
n_matches = ((~is_accepted).cumsum(dim=-1) < 1).sum() # this is `n` in algorithm 1
|
||||
|
||||
# Ensure we don't generate beyond max_len or an EOS token (not in algorithm 1, but needed for correct behavior)
|
||||
if last_assistant_token_is_eos and n_matches == candidate_length:
|
||||
if is_done_candidate and n_matches == candidate_length:
|
||||
# Output length is assumed to be `n_matches + 1`. Since we won't generate another token with the target model
|
||||
# due to acceptance on EOS we fix `n_matches`
|
||||
n_matches -= 1
|
||||
valid_tokens = new_candidate_input_ids[:, : n_matches + 1]
|
||||
else:
|
||||
n_matches = min(n_matches, max_matches)
|
||||
|
||||
# Next token selection: if there is a rejection, adjust the distribution from the main model before sampling.
|
||||
gamma = min(candidate_logits.shape[1], max_matches)
|
||||
gamma = candidate_logits.shape[1]
|
||||
p_n_plus_1 = p[:, n_matches, :]
|
||||
if n_matches < gamma:
|
||||
q_n_plus_1 = q[:, n_matches, :]
|
||||
|
||||
@@ -26,6 +26,7 @@ if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers.generation import (
|
||||
EosTokenCriteria,
|
||||
MaxLengthCriteria,
|
||||
MaxNewTokensCriteria,
|
||||
MaxTimeCriteria,
|
||||
@@ -98,6 +99,22 @@ class StoppingCriteriaTestCase(unittest.TestCase):
|
||||
criteria = MaxTimeCriteria(max_time=0.1, initial_timestamp=time.time() - 0.2)
|
||||
self.assertTrue(all(criteria(input_ids, scores)))
|
||||
|
||||
def test_eos_token_criteria(self):
|
||||
criteria = EosTokenCriteria(eos_token_id=0)
|
||||
|
||||
input_ids, scores = self._get_tensors(5)
|
||||
input_ids[:, -1] = 0
|
||||
self.assertTrue(all(criteria(input_ids, scores)))
|
||||
|
||||
input_ids, scores = self._get_tensors(5)
|
||||
input_ids[:2, -1] = 0
|
||||
input_ids[2, -1] = 1
|
||||
self.assertListEqual(criteria(input_ids, scores).tolist(), [True, True, False])
|
||||
|
||||
input_ids, scores = self._get_tensors(5)
|
||||
input_ids[:, -1] = 1
|
||||
self.assertListEqual(criteria(input_ids, scores).tolist(), [False, False, False])
|
||||
|
||||
def test_validate_stopping_criteria(self):
|
||||
validate_stopping_criteria(StoppingCriteriaList([MaxLengthCriteria(10)]), 10)
|
||||
|
||||
|
||||
@@ -1899,14 +1899,12 @@ class UtilsFunctionsTest(unittest.TestCase):
|
||||
]
|
||||
)
|
||||
last_assistant_token_is_eos = False
|
||||
max_matches = 5
|
||||
validated_tokens, n_matches = _speculative_sampling(
|
||||
candidate_input_ids,
|
||||
candidate_logits,
|
||||
candidate_length,
|
||||
new_logits,
|
||||
last_assistant_token_is_eos,
|
||||
max_matches,
|
||||
)
|
||||
self.assertTrue(n_matches.item() == 2)
|
||||
self.assertTrue(validated_tokens.tolist()[0] == [1, 4, 8])
|
||||
|
||||
Reference in New Issue
Block a user