Move eos_token_id to stopping criteria (#29459)
* add eos stopping criteria * minor fix * Update tests/generation/test_stopping_criteria.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * check eos is not None and fix tests * make style and fixup * Update src/transformers/generation/stopping_criteria.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update tests/generation/test_utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update tests/generation/test_utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation/__init__.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/generation/stopping_criteria.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/generation/stopping_criteria.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/generation/stopping_criteria.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * camel case everywhere * call stopping criteria list for candidate ids * make style and fixup * Empty commit * Empty commit to pass flaky test * set max length in PromptLookupCandidateGenerator * Update src/transformers/generation/utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * lets fix this typo in docs * Update src/transformers/generation/utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/generation/utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * update PR * empty commit --------- Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
31c575bcf1
commit
0efcf32351
@@ -82,6 +82,7 @@ else:
|
|||||||
"MaxNewTokensCriteria",
|
"MaxNewTokensCriteria",
|
||||||
"MaxLengthCriteria",
|
"MaxLengthCriteria",
|
||||||
"MaxTimeCriteria",
|
"MaxTimeCriteria",
|
||||||
|
"EosTokenCriteria",
|
||||||
"StoppingCriteria",
|
"StoppingCriteria",
|
||||||
"StoppingCriteriaList",
|
"StoppingCriteriaList",
|
||||||
"validate_stopping_criteria",
|
"validate_stopping_criteria",
|
||||||
@@ -216,6 +217,7 @@ if TYPE_CHECKING:
|
|||||||
WhisperTimeStampLogitsProcessor,
|
WhisperTimeStampLogitsProcessor,
|
||||||
)
|
)
|
||||||
from .stopping_criteria import (
|
from .stopping_criteria import (
|
||||||
|
EosTokenCriteria,
|
||||||
MaxLengthCriteria,
|
MaxLengthCriteria,
|
||||||
MaxNewTokensCriteria,
|
MaxNewTokensCriteria,
|
||||||
MaxTimeCriteria,
|
MaxTimeCriteria,
|
||||||
|
|||||||
@@ -238,15 +238,20 @@ class PromptLookupCandidateGenerator(CandidateGenerator):
|
|||||||
The maximum ngram size to be considered for matching in the prompt
|
The maximum ngram size to be considered for matching in the prompt
|
||||||
num_output_tokens (`int`):
|
num_output_tokens (`int`):
|
||||||
The number of tokens to be output as candidate tokens.
|
The number of tokens to be output as candidate tokens.
|
||||||
|
max_length (`int`):
|
||||||
|
The number of total maximum tokens that can be generated. For decoder-only models that includes the prompt length.
|
||||||
|
Defaults to 20, which is the max length used as default in generation config.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
num_output_tokens: int = 10,
|
num_output_tokens: int = 10,
|
||||||
max_matching_ngram_size: int = None,
|
max_matching_ngram_size: int = None,
|
||||||
|
max_length: int = 20,
|
||||||
):
|
):
|
||||||
self.num_output_tokens = num_output_tokens
|
self.num_output_tokens = num_output_tokens
|
||||||
self.max_matching_ngram_size = max_matching_ngram_size if max_matching_ngram_size else 2
|
self.max_matching_ngram_size = max_matching_ngram_size if max_matching_ngram_size else 2
|
||||||
|
self.max_length = max_length
|
||||||
|
|
||||||
if self.max_matching_ngram_size <= 0 or self.num_output_tokens <= 0:
|
if self.max_matching_ngram_size <= 0 or self.num_output_tokens <= 0:
|
||||||
raise ValueError("Invalid max_matching_ngram_size or num_output_tokens")
|
raise ValueError("Invalid max_matching_ngram_size or num_output_tokens")
|
||||||
@@ -264,6 +269,10 @@ class PromptLookupCandidateGenerator(CandidateGenerator):
|
|||||||
"""
|
"""
|
||||||
input_length = input_ids.size(1)
|
input_length = input_ids.size(1)
|
||||||
|
|
||||||
|
# Don't generate more than `max_length - 1` candidates since the target model generates one extra token.
|
||||||
|
if self.max_length == input_length + 1:
|
||||||
|
return input_ids, None
|
||||||
|
|
||||||
chosen_ids = None
|
chosen_ids = None
|
||||||
match_found = False
|
match_found = False
|
||||||
for ngram_size in range(min(self.max_matching_ngram_size, input_length - 1), 0, -1):
|
for ngram_size in range(min(self.max_matching_ngram_size, input_length - 1), 0, -1):
|
||||||
@@ -283,7 +292,7 @@ class PromptLookupCandidateGenerator(CandidateGenerator):
|
|||||||
for idx in match_indices:
|
for idx in match_indices:
|
||||||
start_idx = idx + ngram_size
|
start_idx = idx + ngram_size
|
||||||
end_idx = start_idx + self.num_output_tokens
|
end_idx = start_idx + self.num_output_tokens
|
||||||
end_idx = min(end_idx, input_length)
|
end_idx = min(end_idx, input_length, self.max_length)
|
||||||
|
|
||||||
if start_idx < end_idx:
|
if start_idx < end_idx:
|
||||||
chosen_ids = input_ids[0, start_idx:end_idx]
|
chosen_ids = input_ids[0, start_idx:end_idx]
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import time
|
|||||||
import warnings
|
import warnings
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import Optional
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -129,6 +129,27 @@ class MaxTimeCriteria(StoppingCriteria):
|
|||||||
return torch.full((input_ids.shape[0],), is_done, device=input_ids.device, dtype=torch.bool)
|
return torch.full((input_ids.shape[0],), is_done, device=input_ids.device, dtype=torch.bool)
|
||||||
|
|
||||||
|
|
||||||
|
class EosTokenCriteria(StoppingCriteria):
    """
    This class can be used to stop generation whenever the "end-of-sequence" token is generated.
    By default, it uses the `model.generation_config.eos_token_id`.

    Args:
        eos_token_id (`Union[int, List[int], torch.Tensor]`):
            The id of the *end-of-sequence* token. Optionally, use a list (or a tensor of ids) to set multiple
            *end-of-sequence* tokens.
    """

    def __init__(self, eos_token_id: Union[int, List[int], torch.Tensor]):
        if isinstance(eos_token_id, torch.Tensor):
            # Copy rather than calling `torch.tensor` on a tensor (which warns), and flatten so that a 0-dim
            # tensor behaves like a single-element list of ids.
            eos_token_id = eos_token_id.detach().clone().reshape(-1)
        else:
            if isinstance(eos_token_id, int):
                eos_token_id = [eos_token_id]
            eos_token_id = torch.tensor(eos_token_id)
        # Kept as a tensor attribute named `eos_token_id`; it is moved lazily to the device of the inputs.
        self.eos_token_id = eos_token_id

    @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
        # Move the ids to the input device once and keep them there, instead of re-issuing the transfer on
        # every call.
        if self.eos_token_id.device != input_ids.device:
            self.eos_token_id = self.eos_token_id.to(input_ids.device)
        # One boolean per row: True where the last token of that row is one of the EOS ids.
        is_done = torch.isin(input_ids[:, -1], self.eos_token_id)
        return is_done
|
||||||
|
|
||||||
|
|
||||||
class StoppingCriteriaList(list):
|
class StoppingCriteriaList(list):
|
||||||
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
|
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
|
||||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
|
||||||
|
|||||||
@@ -75,6 +75,7 @@ from .logits_process import (
|
|||||||
UnbatchedClassifierFreeGuidanceLogitsProcessor,
|
UnbatchedClassifierFreeGuidanceLogitsProcessor,
|
||||||
)
|
)
|
||||||
from .stopping_criteria import (
|
from .stopping_criteria import (
|
||||||
|
EosTokenCriteria,
|
||||||
MaxLengthCriteria,
|
MaxLengthCriteria,
|
||||||
MaxTimeCriteria,
|
MaxTimeCriteria,
|
||||||
StoppingCriteria,
|
StoppingCriteria,
|
||||||
@@ -690,6 +691,7 @@ class GenerationMixin:
|
|||||||
candidate_generator = PromptLookupCandidateGenerator(
|
candidate_generator = PromptLookupCandidateGenerator(
|
||||||
num_output_tokens=generation_config.prompt_lookup_num_tokens,
|
num_output_tokens=generation_config.prompt_lookup_num_tokens,
|
||||||
max_matching_ngram_size=generation_config.max_matching_ngram_size,
|
max_matching_ngram_size=generation_config.max_matching_ngram_size,
|
||||||
|
max_length=generation_config.max_length,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
candidate_generator = AssistedCandidateGenerator(
|
candidate_generator = AssistedCandidateGenerator(
|
||||||
@@ -892,6 +894,8 @@ class GenerationMixin:
|
|||||||
)
|
)
|
||||||
if generation_config.max_time is not None:
|
if generation_config.max_time is not None:
|
||||||
criteria.append(MaxTimeCriteria(max_time=generation_config.max_time))
|
criteria.append(MaxTimeCriteria(max_time=generation_config.max_time))
|
||||||
|
if generation_config.eos_token_id is not None:
|
||||||
|
criteria.append(EosTokenCriteria(eos_token_id=generation_config.eos_token_id))
|
||||||
criteria = self._merge_criteria_processor_list(criteria, stopping_criteria)
|
criteria = self._merge_criteria_processor_list(criteria, stopping_criteria)
|
||||||
return criteria
|
return criteria
|
||||||
|
|
||||||
@@ -1306,7 +1310,7 @@ class GenerationMixin:
|
|||||||
|
|
||||||
Return:
|
Return:
|
||||||
[`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
|
[`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
|
||||||
or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`.
|
or when `config.return_dict_in_generate=True`) or a `torch.LongTensor`.
|
||||||
|
|
||||||
If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
|
If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
|
||||||
[`~utils.ModelOutput`] types are:
|
[`~utils.ModelOutput`] types are:
|
||||||
@@ -1515,7 +1519,6 @@ class GenerationMixin:
|
|||||||
logits_warper=self._get_logits_warper(generation_config) if generation_config.do_sample else None,
|
logits_warper=self._get_logits_warper(generation_config) if generation_config.do_sample else None,
|
||||||
stopping_criteria=prepared_stopping_criteria,
|
stopping_criteria=prepared_stopping_criteria,
|
||||||
pad_token_id=generation_config.pad_token_id,
|
pad_token_id=generation_config.pad_token_id,
|
||||||
eos_token_id=generation_config.eos_token_id,
|
|
||||||
output_scores=generation_config.output_scores,
|
output_scores=generation_config.output_scores,
|
||||||
output_logits=generation_config.output_logits,
|
output_logits=generation_config.output_logits,
|
||||||
return_dict_in_generate=generation_config.return_dict_in_generate,
|
return_dict_in_generate=generation_config.return_dict_in_generate,
|
||||||
@@ -1530,7 +1533,6 @@ class GenerationMixin:
|
|||||||
logits_processor=prepared_logits_processor,
|
logits_processor=prepared_logits_processor,
|
||||||
stopping_criteria=prepared_stopping_criteria,
|
stopping_criteria=prepared_stopping_criteria,
|
||||||
pad_token_id=generation_config.pad_token_id,
|
pad_token_id=generation_config.pad_token_id,
|
||||||
eos_token_id=generation_config.eos_token_id,
|
|
||||||
output_scores=generation_config.output_scores,
|
output_scores=generation_config.output_scores,
|
||||||
output_logits=generation_config.output_logits,
|
output_logits=generation_config.output_logits,
|
||||||
return_dict_in_generate=generation_config.return_dict_in_generate,
|
return_dict_in_generate=generation_config.return_dict_in_generate,
|
||||||
@@ -1550,7 +1552,6 @@ class GenerationMixin:
|
|||||||
logits_processor=prepared_logits_processor,
|
logits_processor=prepared_logits_processor,
|
||||||
stopping_criteria=prepared_stopping_criteria,
|
stopping_criteria=prepared_stopping_criteria,
|
||||||
pad_token_id=generation_config.pad_token_id,
|
pad_token_id=generation_config.pad_token_id,
|
||||||
eos_token_id=generation_config.eos_token_id,
|
|
||||||
output_scores=generation_config.output_scores,
|
output_scores=generation_config.output_scores,
|
||||||
output_logits=generation_config.output_logits,
|
output_logits=generation_config.output_logits,
|
||||||
return_dict_in_generate=generation_config.return_dict_in_generate,
|
return_dict_in_generate=generation_config.return_dict_in_generate,
|
||||||
@@ -1579,7 +1580,6 @@ class GenerationMixin:
|
|||||||
logits_warper=logits_warper,
|
logits_warper=logits_warper,
|
||||||
stopping_criteria=prepared_stopping_criteria,
|
stopping_criteria=prepared_stopping_criteria,
|
||||||
pad_token_id=generation_config.pad_token_id,
|
pad_token_id=generation_config.pad_token_id,
|
||||||
eos_token_id=generation_config.eos_token_id,
|
|
||||||
output_scores=generation_config.output_scores,
|
output_scores=generation_config.output_scores,
|
||||||
output_logits=generation_config.output_logits,
|
output_logits=generation_config.output_logits,
|
||||||
return_dict_in_generate=generation_config.return_dict_in_generate,
|
return_dict_in_generate=generation_config.return_dict_in_generate,
|
||||||
@@ -1613,7 +1613,6 @@ class GenerationMixin:
|
|||||||
logits_processor=prepared_logits_processor,
|
logits_processor=prepared_logits_processor,
|
||||||
stopping_criteria=prepared_stopping_criteria,
|
stopping_criteria=prepared_stopping_criteria,
|
||||||
pad_token_id=generation_config.pad_token_id,
|
pad_token_id=generation_config.pad_token_id,
|
||||||
eos_token_id=generation_config.eos_token_id,
|
|
||||||
output_scores=generation_config.output_scores,
|
output_scores=generation_config.output_scores,
|
||||||
output_logits=generation_config.output_logits,
|
output_logits=generation_config.output_logits,
|
||||||
return_dict_in_generate=generation_config.return_dict_in_generate,
|
return_dict_in_generate=generation_config.return_dict_in_generate,
|
||||||
@@ -1653,7 +1652,6 @@ class GenerationMixin:
|
|||||||
logits_warper=logits_warper,
|
logits_warper=logits_warper,
|
||||||
stopping_criteria=prepared_stopping_criteria,
|
stopping_criteria=prepared_stopping_criteria,
|
||||||
pad_token_id=generation_config.pad_token_id,
|
pad_token_id=generation_config.pad_token_id,
|
||||||
eos_token_id=generation_config.eos_token_id,
|
|
||||||
output_scores=generation_config.output_scores,
|
output_scores=generation_config.output_scores,
|
||||||
output_logits=generation_config.output_logits,
|
output_logits=generation_config.output_logits,
|
||||||
return_dict_in_generate=generation_config.return_dict_in_generate,
|
return_dict_in_generate=generation_config.return_dict_in_generate,
|
||||||
@@ -1687,7 +1685,6 @@ class GenerationMixin:
|
|||||||
logits_processor=prepared_logits_processor,
|
logits_processor=prepared_logits_processor,
|
||||||
stopping_criteria=prepared_stopping_criteria,
|
stopping_criteria=prepared_stopping_criteria,
|
||||||
pad_token_id=generation_config.pad_token_id,
|
pad_token_id=generation_config.pad_token_id,
|
||||||
eos_token_id=generation_config.eos_token_id,
|
|
||||||
output_scores=generation_config.output_scores,
|
output_scores=generation_config.output_scores,
|
||||||
output_logits=generation_config.output_logits,
|
output_logits=generation_config.output_logits,
|
||||||
return_dict_in_generate=generation_config.return_dict_in_generate,
|
return_dict_in_generate=generation_config.return_dict_in_generate,
|
||||||
@@ -1761,7 +1758,6 @@ class GenerationMixin:
|
|||||||
logits_processor=prepared_logits_processor,
|
logits_processor=prepared_logits_processor,
|
||||||
stopping_criteria=prepared_stopping_criteria,
|
stopping_criteria=prepared_stopping_criteria,
|
||||||
pad_token_id=generation_config.pad_token_id,
|
pad_token_id=generation_config.pad_token_id,
|
||||||
eos_token_id=generation_config.eos_token_id,
|
|
||||||
output_scores=generation_config.output_scores,
|
output_scores=generation_config.output_scores,
|
||||||
output_logits=generation_config.output_logits,
|
output_logits=generation_config.output_logits,
|
||||||
return_dict_in_generate=generation_config.return_dict_in_generate,
|
return_dict_in_generate=generation_config.return_dict_in_generate,
|
||||||
@@ -1916,11 +1912,28 @@ class GenerationMixin:
|
|||||||
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
|
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
|
||||||
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
||||||
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
||||||
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
if eos_token_id is not None:
|
||||||
sequential = sequential if sequential is not None else self.generation_config.low_memory
|
logger.warning_once(
|
||||||
|
"`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
|
||||||
|
" `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead."
|
||||||
|
" Otherwise make sure to set `model.generation_config.eos_token_id`",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
|
||||||
|
else:
|
||||||
|
# TODO remove when the method is totally private
|
||||||
|
# need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
|
||||||
|
eos_token_id = [
|
||||||
|
criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
|
||||||
|
]
|
||||||
|
eos_token_id = eos_token_id[0] if eos_token_id else None
|
||||||
|
if eos_token_id is None and self.generation_config.eos_token_id is not None:
|
||||||
|
eos_token_id = self.generation_config.eos_token_id
|
||||||
|
stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
|
||||||
|
|
||||||
if isinstance(eos_token_id, int):
|
if isinstance(eos_token_id, int):
|
||||||
eos_token_id = [eos_token_id]
|
eos_token_id = [eos_token_id]
|
||||||
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
|
sequential = sequential if sequential is not None else self.generation_config.low_memory
|
||||||
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
||||||
output_logits = output_logits if output_logits is not None else self.generation_config.output_logits
|
output_logits = output_logits if output_logits is not None else self.generation_config.output_logits
|
||||||
output_attentions = (
|
output_attentions = (
|
||||||
@@ -2186,12 +2199,6 @@ class GenerationMixin:
|
|||||||
is_encoder_decoder=self.config.is_encoder_decoder,
|
is_encoder_decoder=self.config.is_encoder_decoder,
|
||||||
)
|
)
|
||||||
|
|
||||||
# if eos_token was found in one sentence, set sentence to finished
|
|
||||||
if eos_token_id_tensor is not None:
|
|
||||||
unfinished_sequences = unfinished_sequences.mul(
|
|
||||||
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
|
|
||||||
)
|
|
||||||
|
|
||||||
# stop when each sentence is finished
|
# stop when each sentence is finished
|
||||||
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
|
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
|
||||||
this_peer_finished = unfinished_sequences.max() == 0
|
this_peer_finished = unfinished_sequences.max() == 0
|
||||||
@@ -2365,10 +2372,27 @@ class GenerationMixin:
|
|||||||
)
|
)
|
||||||
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
|
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
|
||||||
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
||||||
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
if eos_token_id is not None:
|
||||||
|
logger.warning_once(
|
||||||
|
"`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
|
||||||
|
" `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead."
|
||||||
|
" Otherwise make sure to set `model.generation_config.eos_token_id`",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
|
||||||
|
else:
|
||||||
|
# TODO remove when the method is totally private
|
||||||
|
# need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
|
||||||
|
eos_token_id = [
|
||||||
|
criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
|
||||||
|
]
|
||||||
|
eos_token_id = eos_token_id[0] if eos_token_id else None
|
||||||
|
if eos_token_id is None and self.generation_config.eos_token_id is not None:
|
||||||
|
eos_token_id = self.generation_config.eos_token_id
|
||||||
|
stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
|
||||||
|
|
||||||
if isinstance(eos_token_id, int):
|
if isinstance(eos_token_id, int):
|
||||||
eos_token_id = [eos_token_id]
|
eos_token_id = [eos_token_id]
|
||||||
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
|
|
||||||
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
||||||
output_attentions = (
|
output_attentions = (
|
||||||
output_attentions if output_attentions is not None else self.generation_config.output_attentions
|
output_attentions if output_attentions is not None else self.generation_config.output_attentions
|
||||||
@@ -2463,12 +2487,6 @@ class GenerationMixin:
|
|||||||
is_encoder_decoder=self.config.is_encoder_decoder,
|
is_encoder_decoder=self.config.is_encoder_decoder,
|
||||||
)
|
)
|
||||||
|
|
||||||
# if eos_token was found in one sentence, set sentence to finished
|
|
||||||
if eos_token_id_tensor is not None:
|
|
||||||
unfinished_sequences = unfinished_sequences.mul(
|
|
||||||
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
|
|
||||||
)
|
|
||||||
|
|
||||||
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
|
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
|
||||||
this_peer_finished = unfinished_sequences.max() == 0
|
this_peer_finished = unfinished_sequences.max() == 0
|
||||||
|
|
||||||
@@ -2650,10 +2668,27 @@ class GenerationMixin:
|
|||||||
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
|
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
|
||||||
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
|
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
|
||||||
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
||||||
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
if eos_token_id is not None:
|
||||||
|
logger.warning_once(
|
||||||
|
"`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
|
||||||
|
" `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead."
|
||||||
|
" Otherwise make sure to set `model.generation_config.eos_token_id`",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
|
||||||
|
else:
|
||||||
|
# TODO remove when the method is totally private
|
||||||
|
# need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
|
||||||
|
eos_token_id = [
|
||||||
|
criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
|
||||||
|
]
|
||||||
|
eos_token_id = eos_token_id[0] if eos_token_id else None
|
||||||
|
if eos_token_id is None and self.generation_config.eos_token_id is not None:
|
||||||
|
eos_token_id = self.generation_config.eos_token_id
|
||||||
|
stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
|
||||||
|
|
||||||
if isinstance(eos_token_id, int):
|
if isinstance(eos_token_id, int):
|
||||||
eos_token_id = [eos_token_id]
|
eos_token_id = [eos_token_id]
|
||||||
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
|
|
||||||
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
||||||
output_logits = output_logits if output_logits is not None else self.generation_config.output_logits
|
output_logits = output_logits if output_logits is not None else self.generation_config.output_logits
|
||||||
output_attentions = (
|
output_attentions = (
|
||||||
@@ -2751,12 +2786,6 @@ class GenerationMixin:
|
|||||||
is_encoder_decoder=self.config.is_encoder_decoder,
|
is_encoder_decoder=self.config.is_encoder_decoder,
|
||||||
)
|
)
|
||||||
|
|
||||||
# if eos_token was found in one sentence, set sentence to finished
|
|
||||||
if eos_token_id_tensor is not None:
|
|
||||||
unfinished_sequences = unfinished_sequences.mul(
|
|
||||||
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
|
|
||||||
)
|
|
||||||
|
|
||||||
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
|
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
|
||||||
this_peer_finished = unfinished_sequences.max() == 0
|
this_peer_finished = unfinished_sequences.max() == 0
|
||||||
|
|
||||||
@@ -2966,7 +2995,25 @@ class GenerationMixin:
|
|||||||
if len(stopping_criteria) == 0:
|
if len(stopping_criteria) == 0:
|
||||||
warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
|
warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
|
||||||
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
||||||
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
if eos_token_id is not None:
|
||||||
|
logger.warning_once(
|
||||||
|
"`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
|
||||||
|
" `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead."
|
||||||
|
" Otherwise make sure to set `model.generation_config.eos_token_id`",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
|
||||||
|
else:
|
||||||
|
# TODO remove when the method is totally private and beam scorer refactored
|
||||||
|
# need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
|
||||||
|
eos_token_id = [
|
||||||
|
criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
|
||||||
|
]
|
||||||
|
eos_token_id = eos_token_id[0] if eos_token_id else None
|
||||||
|
if eos_token_id is None and self.generation_config.eos_token_id is not None:
|
||||||
|
eos_token_id = self.generation_config.eos_token_id
|
||||||
|
stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
|
||||||
|
|
||||||
if isinstance(eos_token_id, int):
|
if isinstance(eos_token_id, int):
|
||||||
eos_token_id = [eos_token_id]
|
eos_token_id = [eos_token_id]
|
||||||
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
||||||
@@ -3351,7 +3398,25 @@ class GenerationMixin:
|
|||||||
)
|
)
|
||||||
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
|
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
|
||||||
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
||||||
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
if eos_token_id is not None:
|
||||||
|
logger.warning_once(
|
||||||
|
"`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
|
||||||
|
" `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead."
|
||||||
|
" Otherwise make sure to set `model.generation_config.eos_token_id`",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
|
||||||
|
else:
|
||||||
|
# TODO remove when the method is totally private and beam scorer refactored
|
||||||
|
# need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
|
||||||
|
eos_token_id = [
|
||||||
|
criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
|
||||||
|
]
|
||||||
|
eos_token_id = eos_token_id[0] if eos_token_id else None
|
||||||
|
if eos_token_id is None and self.generation_config.eos_token_id is not None:
|
||||||
|
eos_token_id = self.generation_config.eos_token_id
|
||||||
|
stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
|
||||||
|
|
||||||
if isinstance(eos_token_id, int):
|
if isinstance(eos_token_id, int):
|
||||||
eos_token_id = [eos_token_id]
|
eos_token_id = [eos_token_id]
|
||||||
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
||||||
@@ -3688,7 +3753,25 @@ class GenerationMixin:
|
|||||||
)
|
)
|
||||||
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
|
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
|
||||||
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
||||||
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
if eos_token_id is not None:
|
||||||
|
logger.warning_once(
|
||||||
|
"`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
|
||||||
|
" `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead."
|
||||||
|
" Otherwise make sure to set `model.generation_config.eos_token_id`",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
|
||||||
|
else:
|
||||||
|
# TODO remove when the method is totally private and beam scorer refactored
|
||||||
|
# need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
|
||||||
|
eos_token_id = [
|
||||||
|
criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
|
||||||
|
]
|
||||||
|
eos_token_id = eos_token_id[0] if eos_token_id else None
|
||||||
|
if eos_token_id is None and self.generation_config.eos_token_id is not None:
|
||||||
|
eos_token_id = self.generation_config.eos_token_id
|
||||||
|
stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
|
||||||
|
|
||||||
if isinstance(eos_token_id, int):
|
if isinstance(eos_token_id, int):
|
||||||
eos_token_id = [eos_token_id]
|
eos_token_id = [eos_token_id]
|
||||||
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
||||||
@@ -4089,7 +4172,25 @@ class GenerationMixin:
|
|||||||
if len(stopping_criteria) == 0:
|
if len(stopping_criteria) == 0:
|
||||||
warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
|
warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
|
||||||
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
||||||
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
if eos_token_id is not None:
|
||||||
|
logger.warning_once(
|
||||||
|
"`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
|
||||||
|
" `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead."
|
||||||
|
" Otherwise make sure to set `model.generation_config.eos_token_id`",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
|
||||||
|
else:
|
||||||
|
# TODO remove when the method is totally private and beam scorer refactored
|
||||||
|
# need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
|
||||||
|
eos_token_id = [
|
||||||
|
criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
|
||||||
|
]
|
||||||
|
eos_token_id = eos_token_id[0] if eos_token_id else None
|
||||||
|
if eos_token_id is None and self.generation_config.eos_token_id is not None:
|
||||||
|
eos_token_id = self.generation_config.eos_token_id
|
||||||
|
stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
|
||||||
|
|
||||||
if isinstance(eos_token_id, int):
|
if isinstance(eos_token_id, int):
|
||||||
eos_token_id = [eos_token_id]
|
eos_token_id = [eos_token_id]
|
||||||
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
||||||
@@ -4421,12 +4522,27 @@ class GenerationMixin:
|
|||||||
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
|
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
|
||||||
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
||||||
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
||||||
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
if eos_token_id is not None:
|
||||||
if eos_token_id is not None and pad_token_id is None:
|
logger.warning_once(
|
||||||
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
|
"`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
|
||||||
|
" `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead."
|
||||||
|
" Otherwise make sure to set `model.generation_config.eos_token_id`",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
|
||||||
|
else:
|
||||||
|
# TODO remove when the method is totally private and beam scorer refactored
|
||||||
|
# need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
|
||||||
|
eos_token_id = [
|
||||||
|
criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
|
||||||
|
]
|
||||||
|
eos_token_id = eos_token_id[0] if eos_token_id else None
|
||||||
|
if eos_token_id is None and self.generation_config.eos_token_id is not None:
|
||||||
|
eos_token_id = self.generation_config.eos_token_id
|
||||||
|
stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
|
||||||
|
|
||||||
if isinstance(eos_token_id, int):
|
if isinstance(eos_token_id, int):
|
||||||
eos_token_id = [eos_token_id]
|
eos_token_id = [eos_token_id]
|
||||||
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
|
|
||||||
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
||||||
output_logits = output_logits if output_logits is not None else self.generation_config.output_logits
|
output_logits = output_logits if output_logits is not None else self.generation_config.output_logits
|
||||||
output_attentions = (
|
output_attentions = (
|
||||||
@@ -4462,9 +4578,6 @@ class GenerationMixin:
|
|||||||
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
|
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
|
||||||
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
|
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
|
||||||
|
|
||||||
# other auxiliary variables
|
|
||||||
max_len = stopping_criteria[0].max_length
|
|
||||||
|
|
||||||
this_peer_finished = False
|
this_peer_finished = False
|
||||||
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
|
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
|
||||||
cur_len = input_ids.shape[-1]
|
cur_len = input_ids.shape[-1]
|
||||||
@@ -4476,13 +4589,7 @@ class GenerationMixin:
|
|||||||
candidate_logits = candidate_logits.to(self.device)
|
candidate_logits = candidate_logits.to(self.device)
|
||||||
|
|
||||||
candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
|
candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
|
||||||
last_assistant_token_is_eos = (
|
is_done_candidate = stopping_criteria(candidate_input_ids, None)
|
||||||
~candidate_input_ids[:, -1]
|
|
||||||
.tile(eos_token_id_tensor.shape[0], 1)
|
|
||||||
.ne(eos_token_id_tensor.unsqueeze(1))
|
|
||||||
.prod(dim=0)
|
|
||||||
.bool()
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2. Use the original model to obtain the next token logits given the candidate sequence. We obtain
|
# 2. Use the original model to obtain the next token logits given the candidate sequence. We obtain
|
||||||
# `candidate_length + 1` relevant logits from this process: in the event that all candidates are correct,
|
# `candidate_length + 1` relevant logits from this process: in the event that all candidates are correct,
|
||||||
@@ -4525,15 +4632,13 @@ class GenerationMixin:
|
|||||||
# 3. Select the accepted tokens. There are two possible cases:
|
# 3. Select the accepted tokens. There are two possible cases:
|
||||||
# Case 1: `do_sample=True` and we have logits for the candidates (originally from speculative decoding)
|
# Case 1: `do_sample=True` and we have logits for the candidates (originally from speculative decoding)
|
||||||
# 👉 Apply algorithm 1 from the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf).
|
# 👉 Apply algorithm 1 from the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf).
|
||||||
max_matches = max_len - cur_len - 1
|
|
||||||
if do_sample and candidate_logits is not None:
|
if do_sample and candidate_logits is not None:
|
||||||
valid_tokens, n_matches = _speculative_sampling(
|
valid_tokens, n_matches = _speculative_sampling(
|
||||||
candidate_input_ids,
|
candidate_input_ids,
|
||||||
candidate_logits,
|
candidate_logits,
|
||||||
candidate_length,
|
candidate_length,
|
||||||
new_logits,
|
new_logits,
|
||||||
last_assistant_token_is_eos,
|
is_done_candidate,
|
||||||
max_matches,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Case 2: all other cases (originally from assisted generation) 👉 Compare the tokens selected from the
|
# Case 2: all other cases (originally from assisted generation) 👉 Compare the tokens selected from the
|
||||||
@@ -4550,9 +4655,8 @@ class GenerationMixin:
|
|||||||
n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum()
|
n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum()
|
||||||
|
|
||||||
# Ensure we don't generate beyond max_len or an EOS token
|
# Ensure we don't generate beyond max_len or an EOS token
|
||||||
if last_assistant_token_is_eos and n_matches == candidate_length:
|
if is_done_candidate and n_matches == candidate_length:
|
||||||
n_matches -= 1
|
n_matches -= 1
|
||||||
n_matches = min(n_matches, max_matches)
|
|
||||||
valid_tokens = selected_tokens[:, : n_matches + 1]
|
valid_tokens = selected_tokens[:, : n_matches + 1]
|
||||||
|
|
||||||
# 4. Update variables according to the number of matching assistant tokens. Remember: the token generated
|
# 4. Update variables according to the number of matching assistant tokens. Remember: the token generated
|
||||||
@@ -4625,15 +4729,6 @@ class GenerationMixin:
|
|||||||
is_encoder_decoder=self.config.is_encoder_decoder,
|
is_encoder_decoder=self.config.is_encoder_decoder,
|
||||||
)
|
)
|
||||||
|
|
||||||
# if eos_token was found in one sentence, set sentence to finished
|
|
||||||
if eos_token_id_tensor is not None:
|
|
||||||
unfinished_sequences = unfinished_sequences.mul(
|
|
||||||
input_ids[:, -1]
|
|
||||||
.tile(eos_token_id_tensor.shape[0], 1)
|
|
||||||
.ne(eos_token_id_tensor.unsqueeze(1))
|
|
||||||
.prod(dim=0)
|
|
||||||
)
|
|
||||||
|
|
||||||
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
|
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
|
||||||
this_peer_finished = unfinished_sequences.max() == 0
|
this_peer_finished = unfinished_sequences.max() == 0
|
||||||
|
|
||||||
@@ -4678,8 +4773,7 @@ def _speculative_sampling(
|
|||||||
candidate_logits,
|
candidate_logits,
|
||||||
candidate_length,
|
candidate_length,
|
||||||
new_logits,
|
new_logits,
|
||||||
last_assistant_token_is_eos,
|
is_done_candidate,
|
||||||
max_matches,
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Applies sampling as in the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf, algorithm 1). Returns
|
Applies sampling as in the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf, algorithm 1). Returns
|
||||||
@@ -4704,16 +4798,14 @@ def _speculative_sampling(
|
|||||||
n_matches = ((~is_accepted).cumsum(dim=-1) < 1).sum() # this is `n` in algorithm 1
|
n_matches = ((~is_accepted).cumsum(dim=-1) < 1).sum() # this is `n` in algorithm 1
|
||||||
|
|
||||||
# Ensure we don't generate beyond max_len or an EOS token (not in algorithm 1, but needed for correct behavior)
|
# Ensure we don't generate beyond max_len or an EOS token (not in algorithm 1, but needed for correct behavior)
|
||||||
if last_assistant_token_is_eos and n_matches == candidate_length:
|
if is_done_candidate and n_matches == candidate_length:
|
||||||
# Output length is assumed to be `n_matches + 1`. Since we won't generate another token with the target model
|
# Output length is assumed to be `n_matches + 1`. Since we won't generate another token with the target model
|
||||||
# due to acceptance on EOS we fix `n_matches`
|
# due to acceptance on EOS we fix `n_matches`
|
||||||
n_matches -= 1
|
n_matches -= 1
|
||||||
valid_tokens = new_candidate_input_ids[:, : n_matches + 1]
|
valid_tokens = new_candidate_input_ids[:, : n_matches + 1]
|
||||||
else:
|
else:
|
||||||
n_matches = min(n_matches, max_matches)
|
|
||||||
|
|
||||||
# Next token selection: if there is a rejection, adjust the distribution from the main model before sampling.
|
# Next token selection: if there is a rejection, adjust the distribution from the main model before sampling.
|
||||||
gamma = min(candidate_logits.shape[1], max_matches)
|
gamma = candidate_logits.shape[1]
|
||||||
p_n_plus_1 = p[:, n_matches, :]
|
p_n_plus_1 = p[:, n_matches, :]
|
||||||
if n_matches < gamma:
|
if n_matches < gamma:
|
||||||
q_n_plus_1 = q[:, n_matches, :]
|
q_n_plus_1 = q[:, n_matches, :]
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ if is_torch_available():
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers.generation import (
|
from transformers.generation import (
|
||||||
|
EosTokenCriteria,
|
||||||
MaxLengthCriteria,
|
MaxLengthCriteria,
|
||||||
MaxNewTokensCriteria,
|
MaxNewTokensCriteria,
|
||||||
MaxTimeCriteria,
|
MaxTimeCriteria,
|
||||||
@@ -98,6 +99,22 @@ class StoppingCriteriaTestCase(unittest.TestCase):
|
|||||||
criteria = MaxTimeCriteria(max_time=0.1, initial_timestamp=time.time() - 0.2)
|
criteria = MaxTimeCriteria(max_time=0.1, initial_timestamp=time.time() - 0.2)
|
||||||
self.assertTrue(all(criteria(input_ids, scores)))
|
self.assertTrue(all(criteria(input_ids, scores)))
|
||||||
|
|
||||||
|
def test_eos_token_criteria(self):
|
||||||
|
criteria = EosTokenCriteria(eos_token_id=0)
|
||||||
|
|
||||||
|
input_ids, scores = self._get_tensors(5)
|
||||||
|
input_ids[:, -1] = 0
|
||||||
|
self.assertTrue(all(criteria(input_ids, scores)))
|
||||||
|
|
||||||
|
input_ids, scores = self._get_tensors(5)
|
||||||
|
input_ids[:2, -1] = 0
|
||||||
|
input_ids[2, -1] = 1
|
||||||
|
self.assertListEqual(criteria(input_ids, scores).tolist(), [True, True, False])
|
||||||
|
|
||||||
|
input_ids, scores = self._get_tensors(5)
|
||||||
|
input_ids[:, -1] = 1
|
||||||
|
self.assertListEqual(criteria(input_ids, scores).tolist(), [False, False, False])
|
||||||
|
|
||||||
def test_validate_stopping_criteria(self):
|
def test_validate_stopping_criteria(self):
|
||||||
validate_stopping_criteria(StoppingCriteriaList([MaxLengthCriteria(10)]), 10)
|
validate_stopping_criteria(StoppingCriteriaList([MaxLengthCriteria(10)]), 10)
|
||||||
|
|
||||||
|
|||||||
@@ -1899,14 +1899,12 @@ class UtilsFunctionsTest(unittest.TestCase):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
last_assistant_token_is_eos = False
|
last_assistant_token_is_eos = False
|
||||||
max_matches = 5
|
|
||||||
validated_tokens, n_matches = _speculative_sampling(
|
validated_tokens, n_matches = _speculative_sampling(
|
||||||
candidate_input_ids,
|
candidate_input_ids,
|
||||||
candidate_logits,
|
candidate_logits,
|
||||||
candidate_length,
|
candidate_length,
|
||||||
new_logits,
|
new_logits,
|
||||||
last_assistant_token_is_eos,
|
last_assistant_token_is_eos,
|
||||||
max_matches,
|
|
||||||
)
|
)
|
||||||
self.assertTrue(n_matches.item() == 2)
|
self.assertTrue(n_matches.item() == 2)
|
||||||
self.assertTrue(validated_tokens.tolist()[0] == [1, 4, 8])
|
self.assertTrue(validated_tokens.tolist()[0] == [1, 4, 8])
|
||||||
|
|||||||
Reference in New Issue
Block a user