[generate] ✨ vectorized beam search ✨ (#35802)
This commit is contained in:
@@ -2118,7 +2118,7 @@ class TFGenerationMixin:
|
|||||||
a greedy approach, otherwise does multinomial sampling without replacement.
|
a greedy approach, otherwise does multinomial sampling without replacement.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
|
input_ids (`tf.Tensor` of shape `(batch_size, num_beams, sequence_length)`):
|
||||||
The sequence used as a prompt for the generation.
|
The sequence used as a prompt for the generation.
|
||||||
do_sample (`bool`, *optional*, defaults to `False`):
|
do_sample (`bool`, *optional*, defaults to `False`):
|
||||||
Whether or not to use sampling ; use greedy decoding otherwise.
|
Whether or not to use sampling ; use greedy decoding otherwise.
|
||||||
|
|||||||
@@ -2322,29 +2322,16 @@ class GenerationMixin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
|
elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
|
||||||
# 11. prepare beam search scorer
|
# 11. interleave input_ids with `num_beams` additional sequences per batch
|
||||||
beam_scorer = BeamSearchScorer(
|
|
||||||
batch_size=batch_size,
|
|
||||||
num_beams=generation_config.num_beams,
|
|
||||||
device=inputs_tensor.device,
|
|
||||||
length_penalty=generation_config.length_penalty,
|
|
||||||
do_early_stopping=generation_config.early_stopping,
|
|
||||||
num_beam_hyps_to_keep=generation_config.num_return_sequences,
|
|
||||||
max_length=generation_config.max_length,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 12. interleave input_ids with `num_beams` additional sequences per batch
|
|
||||||
input_ids, model_kwargs = self._expand_inputs_for_generation(
|
input_ids, model_kwargs = self._expand_inputs_for_generation(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
expand_size=generation_config.num_beams,
|
expand_size=generation_config.num_beams,
|
||||||
is_encoder_decoder=self.config.is_encoder_decoder,
|
is_encoder_decoder=self.config.is_encoder_decoder,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
# 12. run beam sample
|
||||||
# 13. run beam sample
|
|
||||||
result = self._beam_search(
|
result = self._beam_search(
|
||||||
input_ids,
|
input_ids,
|
||||||
beam_scorer,
|
|
||||||
logits_processor=prepared_logits_processor,
|
logits_processor=prepared_logits_processor,
|
||||||
stopping_criteria=prepared_stopping_criteria,
|
stopping_criteria=prepared_stopping_criteria,
|
||||||
generation_config=generation_config,
|
generation_config=generation_config,
|
||||||
@@ -3396,6 +3383,7 @@ class GenerationMixin:
|
|||||||
else:
|
else:
|
||||||
return input_ids
|
return input_ids
|
||||||
|
|
||||||
|
# Auxiliary functions for beam search
|
||||||
def _temporary_reorder_cache(self, past_key_values, beam_idx):
|
def _temporary_reorder_cache(self, past_key_values, beam_idx):
|
||||||
"""
|
"""
|
||||||
Temporary function to handle the different types of cache reordering processes while we roll out `Cache`.
|
Temporary function to handle the different types of cache reordering processes while we roll out `Cache`.
|
||||||
@@ -3422,10 +3410,208 @@ class GenerationMixin:
|
|||||||
past_key_values.reorder_cache(beam_idx)
|
past_key_values.reorder_cache(beam_idx)
|
||||||
return past_key_values
|
return past_key_values
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _flatten_beam_dim(tensor: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""[batch_size, num_beams, ...] -> [batch_size * num_beams, ...]"""
|
||||||
|
shape = list(tensor.shape)
|
||||||
|
return torch.reshape(tensor, [shape[0] * shape[1]] + shape[2:])
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _unflatten_beam_dim(tensor: torch.Tensor, batch_size: int, num_beams: int) -> torch.Tensor:
|
||||||
|
"""[batch_size * num_beams, ...] -> [batch_size, num_beams, ...]"""
|
||||||
|
shape = list(tensor.shape)
|
||||||
|
return torch.reshape(tensor, [batch_size, num_beams] + shape[1:])
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _gather_beams(tensor: torch.Tensor, beam_indices: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Gathers the beam slices indexed by beam_indices into new beam array.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tensor (`torch.Tensor`): A tensor containing data to be gathered. The tensor is a 2D or a 3D tensor
|
||||||
|
with the two first dimensions depicting the batch and the beam dimensions.
|
||||||
|
beam_indices (`torch.Tensor` of shape `(batch_size, num_beams_to_select)`): The indices of the beams to
|
||||||
|
select .
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tensor with the selected beams
|
||||||
|
"""
|
||||||
|
# `take_along_dim` requires its indices arg to have the same number of dims as `input`
|
||||||
|
while len(beam_indices.shape) < len(tensor.shape):
|
||||||
|
beam_indices = beam_indices.unsqueeze(-1)
|
||||||
|
gathered_tensor = torch.take_along_dim(input=tensor, indices=beam_indices, dim=1)
|
||||||
|
return gathered_tensor
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _beam_search_has_unfinished_sequences(
|
||||||
|
running_beam_scores: torch.Tensor,
|
||||||
|
beam_scores: torch.Tensor,
|
||||||
|
is_sent_finished: torch.Tensor,
|
||||||
|
next_token_hits_stopping_criteria: torch.Tensor,
|
||||||
|
cur_len: int,
|
||||||
|
max_length: int,
|
||||||
|
decoder_prompt_len: int,
|
||||||
|
early_stopping: Union[bool, str],
|
||||||
|
length_penalty: float,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Beam Search stopping condition -- halts the generation loop if any of these conditions becomes False
|
||||||
|
"""
|
||||||
|
# a. Can the open beams improve the top completed scores?
|
||||||
|
# early_stopping == False -> apply heuristic = always get the best score from
|
||||||
|
# `cur_len - decoder_prompt_len`. See the discussion below for more details.
|
||||||
|
# https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
|
||||||
|
# early_stopping == "never" -> compute the best score from `max_length` or `cur_len`, depending on the
|
||||||
|
# sign of `length_penalty`. Positive `length_penalty` favors longer sequences, thus we use
|
||||||
|
# `max_length` there.
|
||||||
|
if early_stopping == "never" and length_penalty > 0.0:
|
||||||
|
best_hypothetical_length = max_length - decoder_prompt_len
|
||||||
|
else:
|
||||||
|
best_hypothetical_length = cur_len - decoder_prompt_len
|
||||||
|
best_possible_running_score = running_beam_scores[:, :1] / (best_hypothetical_length**length_penalty)
|
||||||
|
worst_finished_score = torch.where(is_sent_finished, torch.min(beam_scores, dim=1, keepdim=True)[0], -1.0e9)
|
||||||
|
improvement_possible = torch.any(best_possible_running_score > worst_finished_score)
|
||||||
|
|
||||||
|
# b. Is there still a beam without fully completed sequences? This is only relevant if early_stopping is
|
||||||
|
# enabled, where we want to finish as soon as all beams have a completed sequence.
|
||||||
|
exists_open_beam = ~(torch.all(is_sent_finished) & (early_stopping is True))
|
||||||
|
|
||||||
|
# c. Have we hit a stopping criteria with all running sequences and have no way to continue? e.g. we have
|
||||||
|
# reached `max_length``
|
||||||
|
valid_continuations = ~torch.all(next_token_hits_stopping_criteria)
|
||||||
|
|
||||||
|
return improvement_possible & exists_open_beam & valid_continuations
|
||||||
|
|
||||||
|
def _get_top_k_continuations(
|
||||||
|
self,
|
||||||
|
accumulated_log_probs: torch.Tensor,
|
||||||
|
running_sequences: torch.Tensor,
|
||||||
|
running_beam_indices: torch.Tensor,
|
||||||
|
cur_len: int,
|
||||||
|
decoder_prompt_len: int,
|
||||||
|
do_sample: bool,
|
||||||
|
beams_to_keep: int,
|
||||||
|
num_beams: int,
|
||||||
|
vocab_size: int,
|
||||||
|
batch_size: int,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Get top-K continuations given the accumulated log probs on the next token.
|
||||||
|
|
||||||
|
A few notes to understand what's going on:
|
||||||
|
1. Each item in batch has `num_beams` * `vocab_size` candidate continuations. For each item, get the
|
||||||
|
top K [K = (number of EOS tokens + 1) * `num_beams`] candidates with the highest accumulated
|
||||||
|
log-probabilities, or sample them without replacement using the accumulated scores
|
||||||
|
2. We gather the top K (as opposed to `num_beams`, or any number lower than K) here so that we have at
|
||||||
|
least `num_beams` sequences remaining to continue the live beam search.
|
||||||
|
3. Note that other stopping criteria might result in impossible to continue beams, i.e. all continuations
|
||||||
|
selected in this step hit the stopping criteria.
|
||||||
|
"""
|
||||||
|
# TODO (joao): This function should take an optional beam scorer function, to manipulate the scores after
|
||||||
|
# token selection. The function should be an argument exposed, so that custom scoring functions can be
|
||||||
|
# defined.
|
||||||
|
|
||||||
|
# Gather the top K scores from _all_ beams.
|
||||||
|
if do_sample:
|
||||||
|
topk_indices = torch.multinomial(
|
||||||
|
nn.functional.softmax(accumulated_log_probs, dim=-1), num_samples=beams_to_keep
|
||||||
|
)
|
||||||
|
topk_log_probs = torch.gather(input=accumulated_log_probs, dim=1, index=topk_indices)
|
||||||
|
else:
|
||||||
|
topk_log_probs, topk_indices = torch.topk(accumulated_log_probs, k=beams_to_keep)
|
||||||
|
|
||||||
|
# Gather K top beams, recover the beam index by floor division and token id by modulo division
|
||||||
|
topk_current_beam_indices = topk_indices // vocab_size
|
||||||
|
topk_running_beam_indices = self._gather_beams(running_beam_indices, topk_current_beam_indices)
|
||||||
|
topk_running_sequences = self._gather_beams(running_sequences, topk_current_beam_indices)
|
||||||
|
topk_ids = topk_indices % vocab_size
|
||||||
|
|
||||||
|
# Update sequences for the K top-k new sequences.
|
||||||
|
topk_running_sequences[:, :, cur_len] = topk_ids
|
||||||
|
|
||||||
|
# we want to store the beam indices with batch information -> real beam index = beam index % num beams
|
||||||
|
batch_offset = torch.arange(batch_size, device=topk_ids.device).view(-1, 1) * num_beams
|
||||||
|
batch_modified_indices = topk_current_beam_indices + batch_offset
|
||||||
|
topk_running_beam_indices[:, :, cur_len - decoder_prompt_len] = batch_modified_indices
|
||||||
|
|
||||||
|
return topk_log_probs, topk_running_sequences, topk_running_beam_indices
|
||||||
|
|
||||||
|
def _get_running_beams_for_next_iteration(
|
||||||
|
self,
|
||||||
|
topk_log_probs: torch.Tensor,
|
||||||
|
topk_running_sequences: torch.Tensor,
|
||||||
|
topk_running_beam_indices: torch.Tensor,
|
||||||
|
next_token_hits_stopping_criteria: torch.Tensor,
|
||||||
|
num_beams: int,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Given the top-K continuations, their scores, and whether they hit a stopping criteria, select the
|
||||||
|
best non-finished beams to continue beam search in the next iteration.
|
||||||
|
"""
|
||||||
|
# To prevent these just finished sequences from being used in subsequent iterations, set their log probs
|
||||||
|
# to a very large negative value
|
||||||
|
topk_running_log_probs = topk_log_probs + next_token_hits_stopping_criteria.to(torch.float32) * -1.0e9
|
||||||
|
|
||||||
|
next_topk_indices = torch.topk(topk_running_log_probs, k=num_beams)[1]
|
||||||
|
running_sequences = self._gather_beams(topk_running_sequences, next_topk_indices)
|
||||||
|
running_beam_scores = self._gather_beams(topk_running_log_probs, next_topk_indices)
|
||||||
|
running_beam_indices = self._gather_beams(topk_running_beam_indices, next_topk_indices)
|
||||||
|
return running_sequences, running_beam_scores, running_beam_indices
|
||||||
|
|
||||||
|
def _update_finished_beams(
|
||||||
|
self,
|
||||||
|
sequences: torch.Tensor,
|
||||||
|
topk_running_sequences: torch.Tensor,
|
||||||
|
beam_scores: torch.Tensor,
|
||||||
|
topk_log_probs: torch.Tensor,
|
||||||
|
beam_indices: torch.Tensor,
|
||||||
|
topk_running_beam_indices: torch.Tensor,
|
||||||
|
is_sent_finished: torch.Tensor,
|
||||||
|
next_token_hits_stopping_criteria: torch.Tensor,
|
||||||
|
top_num_beam_mask: torch.Tensor,
|
||||||
|
num_beams: int,
|
||||||
|
cur_len: int,
|
||||||
|
decoder_prompt_len: int,
|
||||||
|
length_penalty: float,
|
||||||
|
early_stopping: Union[bool, str],
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Updates the finished beams if (and only if) there are new completed sequences that have a higher score than
|
||||||
|
the current finished sequences.
|
||||||
|
"""
|
||||||
|
# Only the top `num_beam` sequences can be considered for the final returned sequences. Remember: the
|
||||||
|
# remaining sequences only exist as a backup to ensure that we have at least `num_beams` sequences to
|
||||||
|
# continue.
|
||||||
|
did_top_num_beams_just_finished = next_token_hits_stopping_criteria & top_num_beam_mask[None, :]
|
||||||
|
|
||||||
|
# Further process topk logits for the finished beams
|
||||||
|
# - add length penalty
|
||||||
|
topk_log_probs = topk_log_probs / ((cur_len + 1 - decoder_prompt_len) ** length_penalty)
|
||||||
|
# - make sure no scores can be added anymore if beam is full and early stopping is on
|
||||||
|
beams_in_batch_are_full = torch.all(is_sent_finished, axis=-1, keepdims=True) & (early_stopping is True)
|
||||||
|
topk_log_probs += beams_in_batch_are_full.to(torch.float32) * -1.0e9
|
||||||
|
# - make sure still running sequences cannot be chosen as finalized beam
|
||||||
|
topk_log_probs += (~did_top_num_beams_just_finished) * -1.0e9
|
||||||
|
|
||||||
|
# Get finalized `num_beam` sequences for the next generation step -- combine the previous finalized
|
||||||
|
# data with the new finalized sequences (if any, non-finalized sequences have a very large negative score
|
||||||
|
# in this step), and keep the best `num_beams` sequences.
|
||||||
|
merged_sequences = torch.cat((sequences, topk_running_sequences), dim=1)
|
||||||
|
merged_scores = torch.cat((beam_scores, topk_log_probs), dim=1)
|
||||||
|
merged_beam_indices = torch.cat((beam_indices, topk_running_beam_indices), dim=1)
|
||||||
|
merged_is_sent_finished = torch.cat((is_sent_finished, did_top_num_beams_just_finished), dim=1)
|
||||||
|
topk_merged_indices = torch.topk(merged_scores, k=num_beams)[1]
|
||||||
|
sequences = self._gather_beams(merged_sequences, topk_merged_indices)
|
||||||
|
beam_scores = self._gather_beams(merged_scores, topk_merged_indices)
|
||||||
|
beam_indices = self._gather_beams(merged_beam_indices, topk_merged_indices)
|
||||||
|
is_sent_finished = self._gather_beams(merged_is_sent_finished, topk_merged_indices)
|
||||||
|
return sequences, beam_scores, beam_indices, is_sent_finished
|
||||||
|
|
||||||
|
# end of auxiliary functions for beam search
|
||||||
|
|
||||||
def _beam_search(
|
def _beam_search(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor,
|
input_ids: torch.LongTensor,
|
||||||
beam_scorer: BeamScorer,
|
|
||||||
logits_processor: LogitsProcessorList,
|
logits_processor: LogitsProcessorList,
|
||||||
stopping_criteria: StoppingCriteriaList,
|
stopping_criteria: StoppingCriteriaList,
|
||||||
generation_config: GenerationConfig,
|
generation_config: GenerationConfig,
|
||||||
@@ -3436,12 +3622,15 @@ class GenerationMixin:
|
|||||||
Generates sequences of token ids for models with a language modeling head using **beam search decoding** and
|
Generates sequences of token ids for models with a language modeling head using **beam search decoding** and
|
||||||
can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
|
can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
|
||||||
|
|
||||||
|
If it's the first time you're diving into Beam Search, we recommend you read the following blog post:
|
||||||
|
https://huggingface.co/blog/how-to-generate (especially the beam search section).
|
||||||
|
|
||||||
|
You can recompute the sequence scores from the individual scores using the `compute_transition_scores` function
|
||||||
|
(https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationMixin.compute_transition_scores)
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
input_ids (`torch.LongTensor` of shape `(batch_size*num_beams, sequence_length)`):
|
||||||
The sequence used as a prompt for the generation.
|
The sequence used as a prompt for the generation.
|
||||||
beam_scorer (`BeamScorer`):
|
|
||||||
An derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
|
|
||||||
sorted during generation. For more information, the documentation of [`BeamScorer`] should be read.
|
|
||||||
logits_processor (`LogitsProcessorList`):
|
logits_processor (`LogitsProcessorList`):
|
||||||
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
|
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
|
||||||
used to modify the prediction scores of the language modeling head applied at each generation step.
|
used to modify the prediction scores of the language modeling head applied at each generation step.
|
||||||
@@ -3464,7 +3653,8 @@ class GenerationMixin:
|
|||||||
`return_dict_in_generate=True` or a [`~generation.GenerateBeamEncoderDecoderOutput`] if
|
`return_dict_in_generate=True` or a [`~generation.GenerateBeamEncoderDecoderOutput`] if
|
||||||
`model.config.is_encoder_decoder=True`.
|
`model.config.is_encoder_decoder=True`.
|
||||||
"""
|
"""
|
||||||
# init values
|
|
||||||
|
# 1. init beam_search values
|
||||||
pad_token_id = generation_config._pad_token_tensor
|
pad_token_id = generation_config._pad_token_tensor
|
||||||
eos_token_id = generation_config._eos_token_tensor
|
eos_token_id = generation_config._eos_token_tensor
|
||||||
output_attentions = generation_config.output_attentions
|
output_attentions = generation_config.output_attentions
|
||||||
@@ -3472,26 +3662,51 @@ class GenerationMixin:
|
|||||||
output_scores = generation_config.output_scores
|
output_scores = generation_config.output_scores
|
||||||
output_logits = generation_config.output_logits
|
output_logits = generation_config.output_logits
|
||||||
return_dict_in_generate = generation_config.return_dict_in_generate
|
return_dict_in_generate = generation_config.return_dict_in_generate
|
||||||
sequential = generation_config.low_memory
|
|
||||||
do_sample = generation_config.do_sample
|
do_sample = generation_config.do_sample
|
||||||
|
early_stopping = generation_config.early_stopping
|
||||||
|
length_penalty = generation_config.length_penalty
|
||||||
|
max_length = generation_config.max_length
|
||||||
|
num_beams = generation_config.num_beams
|
||||||
|
num_return_sequences = generation_config.num_return_sequences
|
||||||
|
|
||||||
batch_size = len(beam_scorer._beam_hyps)
|
batch_size_unflattened, cur_len = input_ids.shape
|
||||||
num_beams = beam_scorer.num_beams
|
batch_size = batch_size_unflattened // num_beams
|
||||||
|
# TODO (joao): standardize special cases
|
||||||
|
if self.__class__.__name__ == "MoshiDepthDecoder":
|
||||||
|
vocab_size = self.config.audio_vocab_size
|
||||||
|
elif self.__class__.__name__ == "ImageGPTForCausalImageModeling":
|
||||||
|
vocab_size = self.get_output_embeddings().out_features
|
||||||
|
else:
|
||||||
|
vocab_size = self.config.get_text_config().vocab_size
|
||||||
|
decoder_prompt_len = cur_len
|
||||||
|
this_peer_finished = False
|
||||||
|
|
||||||
|
# At each beam search step, we want to keep top K [K = (number of EOS tokens + 1) * `num_beams`] candidates
|
||||||
|
# with the highest log-probabilities, or sample K continuations without replacement. We gather the top K
|
||||||
|
# (as opposed to `num_beams`, or any number lower than K) so that we have at least `num_beams` sequences
|
||||||
|
# non-finished to continue the live beam search, in case the top `num_beams` all select an EOS token.
|
||||||
|
n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0
|
||||||
|
beams_to_keep = max(2, 1 + n_eos_tokens) * num_beams
|
||||||
|
top_num_beam_mask = torch.cat(
|
||||||
|
(torch.ones((num_beams), dtype=torch.bool), torch.zeros((beams_to_keep - num_beams), dtype=torch.bool)),
|
||||||
|
dim=0,
|
||||||
|
).to(input_ids.device)
|
||||||
|
|
||||||
batch_beam_size, cur_len = input_ids.shape
|
|
||||||
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
|
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
|
||||||
|
|
||||||
if num_beams * batch_size != batch_beam_size:
|
# (joao) feature lost in the refactor. Probably won't implement, hurts readbility with minimal gains (there
|
||||||
|
# are newer low-memory alternatives like the offloaded cache)
|
||||||
|
sequential = generation_config.low_memory
|
||||||
|
if sequential:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
|
"`low_memory=True` is not supported after the beam search refactor. Please check the discussion in "
|
||||||
|
"#35802 *after the PR got merged*, and add a comment there if your questions are not yet answered."
|
||||||
)
|
)
|
||||||
|
|
||||||
# init attention / hidden states / scores tuples
|
# 2. init output tuples
|
||||||
scores = () if (return_dict_in_generate and output_scores) else None
|
all_scores = () if (return_dict_in_generate and output_scores) else None
|
||||||
raw_logits = () if (return_dict_in_generate and output_logits) else None
|
raw_logits = () if (return_dict_in_generate and output_logits) else None
|
||||||
beam_indices = (
|
beam_indices = () if (return_dict_in_generate and output_logits) else None
|
||||||
tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
|
|
||||||
)
|
|
||||||
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
|
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
|
||||||
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
|
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
|
||||||
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
|
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
|
||||||
@@ -3503,184 +3718,195 @@ class GenerationMixin:
|
|||||||
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
|
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 3. init running tensors and static-shaped placeholders
|
||||||
|
|
||||||
|
# per batch, beam-item holding current token in loop and completed sequences
|
||||||
|
output_fill_value = pad_token_id or eos_token_id[0] if eos_token_id is not None else -1
|
||||||
|
running_sequences = torch.full(
|
||||||
|
(batch_size, num_beams, max_length),
|
||||||
|
fill_value=output_fill_value,
|
||||||
|
dtype=torch.int64,
|
||||||
|
device=input_ids.device,
|
||||||
|
)
|
||||||
|
running_sequences[:, :, :cur_len] = self._unflatten_beam_dim(input_ids, batch_size, num_beams)
|
||||||
|
sequences = running_sequences.clone().detach()
|
||||||
|
|
||||||
|
# per batch, beam-item score, logprobs
|
||||||
# initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens
|
# initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens
|
||||||
# of the first beam are considered to avoid sampling the exact same tokens across all beams.
|
# of the first beam are considered to avoid sampling the exact same tokens across all beams.
|
||||||
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
|
running_beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
|
||||||
beam_scores[:, 1:] = -1e9
|
running_beam_scores[:, 1:] = -1e9
|
||||||
beam_scores = beam_scores.view((batch_size * num_beams,))
|
beam_scores = torch.full((batch_size, num_beams), fill_value=-1e9, dtype=torch.float, device=input_ids.device)
|
||||||
|
|
||||||
this_peer_finished = False
|
# per batch, beam-item state bit indicating if sentence has finished.
|
||||||
|
is_sent_finished = torch.zeros((batch_size, num_beams), dtype=torch.bool, device=input_ids.device)
|
||||||
|
|
||||||
decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder
|
# per batch, beam-item state bit indicating if there are valid continuations.
|
||||||
|
next_token_hits_stopping_criteria = torch.zeros(
|
||||||
|
(batch_size, num_beams), dtype=torch.bool, device=input_ids.device
|
||||||
|
)
|
||||||
|
|
||||||
|
# per batch selected beam indices
|
||||||
|
running_beam_indices = torch.full(
|
||||||
|
(batch_size, num_beams, max_length - cur_len), fill_value=-1, dtype=torch.int32, device=input_ids.device
|
||||||
|
)
|
||||||
|
beam_indices = running_beam_indices.clone().detach()
|
||||||
|
|
||||||
|
# 4. run the generation loop
|
||||||
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
|
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
|
||||||
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
# a. Forward current tokens, obtain the logits
|
||||||
|
flat_running_sequences = self._flatten_beam_dim(running_sequences[:, :, :cur_len])
|
||||||
|
model_inputs = self.prepare_inputs_for_generation(flat_running_sequences, **model_kwargs)
|
||||||
|
|
||||||
# prepare variable output controls (note: some models won't accept all output controls)
|
# prepare variable output controls (note: some models won't accept all output controls)
|
||||||
model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
|
model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
|
||||||
model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
|
model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
|
||||||
|
|
||||||
# if sequential is True, split the input to batches of batch_size and run sequentially
|
model_outputs = self(**model_inputs, return_dict=True)
|
||||||
if sequential:
|
|
||||||
if any(
|
|
||||||
model_name in self.__class__.__name__.lower()
|
|
||||||
for model_name in [
|
|
||||||
"fsmt",
|
|
||||||
"reformer",
|
|
||||||
"ctrl",
|
|
||||||
"gpt_bigcode",
|
|
||||||
"transo_xl",
|
|
||||||
"xlnet",
|
|
||||||
"cpm",
|
|
||||||
"jamba",
|
|
||||||
]
|
|
||||||
):
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Currently generation for {self.__class__.__name__} is not supported "
|
|
||||||
f"for `low_memory beam_search`. Please open an issue on GitHub if you need this feature."
|
|
||||||
)
|
|
||||||
|
|
||||||
inputs_per_sub_batches = _split_model_inputs(
|
|
||||||
model_inputs,
|
|
||||||
split_size=batch_size,
|
|
||||||
full_batch_size=batch_beam_size,
|
|
||||||
config=self.config.get_text_config(),
|
|
||||||
)
|
|
||||||
outputs_per_sub_batch = [
|
|
||||||
self(**inputs_per_sub_batch, return_dict=True) for inputs_per_sub_batch in inputs_per_sub_batches
|
|
||||||
]
|
|
||||||
|
|
||||||
outputs = stack_model_outputs(outputs_per_sub_batch, self.config.get_text_config())
|
|
||||||
|
|
||||||
else: # Unchanged original behavior
|
|
||||||
outputs = self(**model_inputs, return_dict=True)
|
|
||||||
|
|
||||||
# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
|
# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
|
||||||
model_kwargs = self._update_model_kwargs_for_generation(
|
model_kwargs = self._update_model_kwargs_for_generation(
|
||||||
outputs,
|
model_outputs,
|
||||||
model_kwargs,
|
model_kwargs,
|
||||||
is_encoder_decoder=self.config.is_encoder_decoder,
|
is_encoder_decoder=self.config.is_encoder_decoder,
|
||||||
)
|
)
|
||||||
if synced_gpus and this_peer_finished:
|
if synced_gpus and this_peer_finished:
|
||||||
cur_len = cur_len + 1
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
|
logits = model_outputs.logits[:, -1, :].clone().float() # Clone is needed to avoid keeping a hanging ref
|
||||||
# (the clone itself is always small)
|
logits = logits.to(input_ids.device)
|
||||||
# .float() is needed to retain precision for later logits manipulations
|
|
||||||
next_token_logits = outputs.logits[:, -1, :].clone().float()
|
|
||||||
next_token_logits = next_token_logits.to(input_ids.device)
|
|
||||||
next_token_scores = nn.functional.log_softmax(
|
|
||||||
next_token_logits, dim=-1
|
|
||||||
) # (batch_size * num_beams, vocab_size)
|
|
||||||
|
|
||||||
next_token_scores_processed = logits_processor(input_ids, next_token_scores)
|
# b. Compute log probs -- get log probabilities from logits, process logits with processors (*e.g.*
|
||||||
next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
|
# `temperature`, ...), and add new logprobs to existing running logprobs scores.
|
||||||
next_token_scores_processed
|
log_probs = nn.functional.log_softmax(logits, dim=-1)
|
||||||
)
|
log_probs = logits_processor(flat_running_sequences, log_probs)
|
||||||
|
|
||||||
# Store scores, attentions and hidden_states when required
|
# Store logits, attentions and hidden_states when required
|
||||||
if return_dict_in_generate:
|
if return_dict_in_generate:
|
||||||
if output_scores:
|
|
||||||
scores += (next_token_scores_processed,)
|
|
||||||
if output_logits:
|
if output_logits:
|
||||||
raw_logits += (next_token_logits,)
|
raw_logits += (logits.clone(),)
|
||||||
|
if return_dict_in_generate and output_scores:
|
||||||
|
all_scores += (log_probs.clone(),)
|
||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
decoder_attentions += (
|
decoder_attentions += (
|
||||||
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
|
(model_outputs.decoder_attentions,)
|
||||||
|
if self.config.is_encoder_decoder
|
||||||
|
else (model_outputs.attentions,)
|
||||||
)
|
)
|
||||||
if self.config.is_encoder_decoder:
|
if self.config.is_encoder_decoder:
|
||||||
cross_attentions += (outputs.cross_attentions,)
|
cross_attentions += (model_outputs.cross_attentions,)
|
||||||
|
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
decoder_hidden_states += (
|
decoder_hidden_states += (
|
||||||
(outputs.decoder_hidden_states,)
|
(model_outputs.decoder_hidden_states,)
|
||||||
if self.config.is_encoder_decoder
|
if self.config.is_encoder_decoder
|
||||||
else (outputs.hidden_states,)
|
else (model_outputs.hidden_states,)
|
||||||
)
|
)
|
||||||
|
|
||||||
# reshape for beam search
|
# This is needed to properly delete logits which may be very large for first iteration
|
||||||
vocab_size = next_token_scores.shape[-1]
|
# Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
|
||||||
next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
|
del model_outputs
|
||||||
|
|
||||||
# Beam token selection: pick 1 + eos_token_id.shape[0] next tokens for each beam so we have at least 1
|
log_probs = self._unflatten_beam_dim(log_probs, batch_size, num_beams)
|
||||||
# non eos token per beam.
|
log_probs = log_probs + running_beam_scores[:, :, None]
|
||||||
n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0
|
log_probs = torch.reshape(log_probs, (batch_size, num_beams * vocab_size))
|
||||||
n_tokens_to_keep = max(2, 1 + n_eos_tokens) * num_beams
|
|
||||||
if do_sample:
|
|
||||||
probs = nn.functional.softmax(next_token_scores, dim=-1)
|
|
||||||
next_tokens = torch.multinomial(probs, num_samples=n_tokens_to_keep)
|
|
||||||
next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
|
|
||||||
next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
|
|
||||||
next_tokens = torch.gather(next_tokens, -1, _indices)
|
|
||||||
else:
|
|
||||||
next_token_scores, next_tokens = torch.topk(
|
|
||||||
next_token_scores, n_tokens_to_keep, dim=1, largest=True, sorted=True
|
|
||||||
)
|
|
||||||
|
|
||||||
next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
|
# c. Retrieve top-K continuations, i.e. select the next token (greedy or sampling) and then keep the best
|
||||||
next_tokens = next_tokens % vocab_size
|
# continuations among all beams based on the accumulated scores.
|
||||||
|
topk_log_probs, topk_running_sequences, topk_running_beam_indices = self._get_top_k_continuations(
|
||||||
# stateless
|
accumulated_log_probs=log_probs,
|
||||||
beam_outputs = beam_scorer.process(
|
running_sequences=running_sequences,
|
||||||
input_ids,
|
running_beam_indices=running_beam_indices,
|
||||||
next_token_scores,
|
cur_len=cur_len,
|
||||||
next_tokens,
|
|
||||||
next_indices,
|
|
||||||
pad_token_id=pad_token_id,
|
|
||||||
eos_token_id=eos_token_id,
|
|
||||||
beam_indices=beam_indices,
|
|
||||||
decoder_prompt_len=decoder_prompt_len,
|
decoder_prompt_len=decoder_prompt_len,
|
||||||
|
do_sample=do_sample,
|
||||||
|
beams_to_keep=beams_to_keep,
|
||||||
|
num_beams=num_beams,
|
||||||
|
vocab_size=vocab_size,
|
||||||
|
batch_size=batch_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
beam_scores = beam_outputs["next_beam_scores"]
|
# d. Check which running sequences have finished
|
||||||
beam_next_tokens = beam_outputs["next_beam_tokens"]
|
next_token_hits_stopping_criteria = stopping_criteria(
|
||||||
beam_idx = beam_outputs["next_beam_indices"]
|
self._flatten_beam_dim(topk_running_sequences[:, :, : cur_len + 1]), # remove unfilled token indexes
|
||||||
|
all_scores,
|
||||||
|
)
|
||||||
|
next_token_hits_stopping_criteria = self._unflatten_beam_dim(
|
||||||
|
next_token_hits_stopping_criteria, batch_size, beams_to_keep
|
||||||
|
)
|
||||||
|
|
||||||
input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
|
# e. Get the non-finished running `num_beams` sequences for the next generation step
|
||||||
|
running_sequences, running_beam_scores, running_beam_indices = self._get_running_beams_for_next_iteration(
|
||||||
|
topk_log_probs=topk_log_probs,
|
||||||
|
topk_running_sequences=topk_running_sequences,
|
||||||
|
topk_running_beam_indices=topk_running_beam_indices,
|
||||||
|
next_token_hits_stopping_criteria=next_token_hits_stopping_criteria,
|
||||||
|
num_beams=num_beams,
|
||||||
|
)
|
||||||
|
|
||||||
# This is needed to properly delete outputs.logits which may be very large for first iteration
|
# f. Update the completed beams if a new high score in a finished sequence is found
|
||||||
# Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
|
sequences, beam_scores, beam_indices, is_sent_finished = self._update_finished_beams(
|
||||||
# IMPORTANT: Note that this should appear BEFORE the call to _reorder_cache() to save the maximum memory
|
sequences=sequences,
|
||||||
# (that way the memory peak does not include outputs.logits)
|
topk_running_sequences=topk_running_sequences,
|
||||||
del outputs
|
beam_scores=beam_scores,
|
||||||
|
topk_log_probs=topk_log_probs,
|
||||||
|
beam_indices=beam_indices,
|
||||||
|
topk_running_beam_indices=topk_running_beam_indices,
|
||||||
|
is_sent_finished=is_sent_finished,
|
||||||
|
next_token_hits_stopping_criteria=next_token_hits_stopping_criteria,
|
||||||
|
top_num_beam_mask=top_num_beam_mask,
|
||||||
|
num_beams=num_beams,
|
||||||
|
cur_len=cur_len,
|
||||||
|
decoder_prompt_len=decoder_prompt_len,
|
||||||
|
length_penalty=length_penalty,
|
||||||
|
early_stopping=early_stopping,
|
||||||
|
)
|
||||||
|
|
||||||
|
# g. Prepare remaining data for the next iteration, including computing the stopping condition for
|
||||||
|
# beam search as a whole (as opposed to individual beams, i.e. `stopping_criteria`)
|
||||||
|
|
||||||
|
# pluck the cache from the beam indices that will be used in the next iteration
|
||||||
if model_kwargs.get("past_key_values", None) is not None:
|
if model_kwargs.get("past_key_values", None) is not None:
|
||||||
model_kwargs["past_key_values"] = self._temporary_reorder_cache(
|
model_kwargs["past_key_values"] = self._temporary_reorder_cache(
|
||||||
model_kwargs["past_key_values"], beam_idx
|
past_key_values=model_kwargs["past_key_values"],
|
||||||
|
beam_idx=self._flatten_beam_dim(running_beam_indices[..., cur_len - decoder_prompt_len]),
|
||||||
)
|
)
|
||||||
|
|
||||||
if return_dict_in_generate and output_scores:
|
|
||||||
beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))
|
|
||||||
|
|
||||||
# increase cur_len
|
|
||||||
cur_len = cur_len + 1
|
cur_len = cur_len + 1
|
||||||
|
this_peer_finished = not self._beam_search_has_unfinished_sequences(
|
||||||
|
running_beam_scores,
|
||||||
|
beam_scores,
|
||||||
|
is_sent_finished,
|
||||||
|
next_token_hits_stopping_criteria,
|
||||||
|
cur_len,
|
||||||
|
max_length,
|
||||||
|
decoder_prompt_len,
|
||||||
|
early_stopping,
|
||||||
|
length_penalty,
|
||||||
|
)
|
||||||
|
|
||||||
if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
|
# 5. prepare outputs
|
||||||
this_peer_finished = True
|
# Take best beams for each batch (the score is sorted in descending order)
|
||||||
|
sequences = self._flatten_beam_dim(sequences[:, :num_return_sequences, :])
|
||||||
|
beam_scores = self._flatten_beam_dim(beam_scores[:, :num_return_sequences])
|
||||||
|
beam_indices = self._flatten_beam_dim(beam_indices[:, :num_return_sequences, :])
|
||||||
|
|
||||||
sequence_outputs = beam_scorer.finalize(
|
# Crop the static-shaped tensors to the actual size
|
||||||
input_ids,
|
sequences = sequences[:, :cur_len]
|
||||||
beam_scores,
|
beam_indices = beam_indices[:, : cur_len - decoder_prompt_len]
|
||||||
next_tokens,
|
|
||||||
next_indices,
|
|
||||||
pad_token_id=pad_token_id,
|
|
||||||
eos_token_id=eos_token_id,
|
|
||||||
max_length=stopping_criteria.max_length,
|
|
||||||
beam_indices=beam_indices,
|
|
||||||
decoder_prompt_len=decoder_prompt_len,
|
|
||||||
)
|
|
||||||
|
|
||||||
if return_dict_in_generate:
|
if return_dict_in_generate:
|
||||||
if not output_scores:
|
if not output_scores:
|
||||||
sequence_outputs["sequence_scores"] = None
|
beam_scores = None
|
||||||
|
|
||||||
if self.config.is_encoder_decoder:
|
if self.config.is_encoder_decoder:
|
||||||
return GenerateBeamEncoderDecoderOutput(
|
return GenerateBeamEncoderDecoderOutput(
|
||||||
sequences=sequence_outputs["sequences"],
|
sequences=sequences,
|
||||||
sequences_scores=sequence_outputs["sequence_scores"],
|
sequences_scores=beam_scores,
|
||||||
scores=scores,
|
scores=all_scores,
|
||||||
logits=raw_logits,
|
logits=raw_logits,
|
||||||
beam_indices=sequence_outputs["beam_indices"],
|
beam_indices=beam_indices,
|
||||||
encoder_attentions=encoder_attentions,
|
encoder_attentions=encoder_attentions,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
decoder_attentions=decoder_attentions,
|
decoder_attentions=decoder_attentions,
|
||||||
@@ -3690,17 +3916,17 @@ class GenerationMixin:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return GenerateBeamDecoderOnlyOutput(
|
return GenerateBeamDecoderOnlyOutput(
|
||||||
sequences=sequence_outputs["sequences"],
|
sequences=sequences,
|
||||||
sequences_scores=sequence_outputs["sequence_scores"],
|
sequences_scores=beam_scores,
|
||||||
scores=scores,
|
scores=all_scores,
|
||||||
logits=raw_logits,
|
logits=raw_logits,
|
||||||
beam_indices=sequence_outputs["beam_indices"],
|
beam_indices=beam_indices,
|
||||||
attentions=decoder_attentions,
|
attentions=decoder_attentions,
|
||||||
hidden_states=decoder_hidden_states,
|
hidden_states=decoder_hidden_states,
|
||||||
past_key_values=model_kwargs.get("past_key_values"),
|
past_key_values=model_kwargs.get("past_key_values"),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return sequence_outputs["sequences"]
|
return sequences
|
||||||
|
|
||||||
def _group_beam_search(
|
def _group_beam_search(
|
||||||
self,
|
self,
|
||||||
@@ -3717,7 +3943,7 @@ class GenerationMixin:
|
|||||||
decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
|
decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
input_ids (`torch.LongTensor` of shape `(batch_size*num_beams, sequence_length)`):
|
||||||
The sequence used as a prompt for the generation.
|
The sequence used as a prompt for the generation.
|
||||||
beam_scorer (`BeamScorer`):
|
beam_scorer (`BeamScorer`):
|
||||||
An derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
|
An derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
|
||||||
@@ -4008,7 +4234,7 @@ class GenerationMixin:
|
|||||||
decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
|
decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
input_ids (`torch.LongTensor` of shape `(batch_size*num_beams, sequence_length)`):
|
||||||
The sequence used as a prompt for the generation.
|
The sequence used as a prompt for the generation.
|
||||||
constrained_beam_scorer (`ConstrainedBeamSearchScorer`):
|
constrained_beam_scorer (`ConstrainedBeamSearchScorer`):
|
||||||
A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
|
A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
from ...generation import BeamSearchScorer, GenerationConfig, LogitsProcessorList, StoppingCriteriaList
|
from ...generation import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
|
||||||
from ...modeling_outputs import ModelOutput
|
from ...modeling_outputs import ModelOutput
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...utils import add_start_docstrings_to_model_forward, logging, replace_return_docstrings
|
from ...utils import add_start_docstrings_to_model_forward, logging, replace_return_docstrings
|
||||||
@@ -1563,18 +1563,8 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
|||||||
elif generation_config.num_beams > 1:
|
elif generation_config.num_beams > 1:
|
||||||
if generation_config.num_return_sequences > generation_config.num_beams:
|
if generation_config.num_return_sequences > generation_config.num_beams:
|
||||||
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
|
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
|
||||||
beam_scorer = BeamSearchScorer(
|
|
||||||
batch_size=batch_size,
|
|
||||||
num_beams=generation_config.num_beams,
|
|
||||||
device=self.device,
|
|
||||||
length_penalty=generation_config.length_penalty,
|
|
||||||
do_early_stopping=generation_config.early_stopping,
|
|
||||||
num_beam_hyps_to_keep=generation_config.num_return_sequences,
|
|
||||||
max_length=generation_config.max_length,
|
|
||||||
)
|
|
||||||
return self._beam_search(
|
return self._beam_search(
|
||||||
input_ids,
|
input_ids,
|
||||||
beam_scorer,
|
|
||||||
logits_processor=pre_processor,
|
logits_processor=pre_processor,
|
||||||
stopping_criteria=prepared_stopping_criteria,
|
stopping_criteria=prepared_stopping_criteria,
|
||||||
generation_config=generation_config,
|
generation_config=generation_config,
|
||||||
|
|||||||
@@ -1099,70 +1099,6 @@ class GenerationTesterMixin:
|
|||||||
)
|
)
|
||||||
self.assertListEqual(low_output.tolist(), high_output.tolist())
|
self.assertListEqual(low_output.tolist(), high_output.tolist())
|
||||||
|
|
||||||
@pytest.mark.generate
|
|
||||||
def test_beam_search_low_memory(self):
|
|
||||||
# Check that choosing 'low_memory' does not change the model output
|
|
||||||
for model_class in self.all_generative_model_classes:
|
|
||||||
if model_class._is_stateful:
|
|
||||||
self.skipTest(reason="May fix in the future: need custom cache handling")
|
|
||||||
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
|
|
||||||
self.skipTest(reason="Won't fix: old model with different cache format")
|
|
||||||
if any(
|
|
||||||
model_name in model_class.__name__.lower()
|
|
||||||
for model_name in [
|
|
||||||
"ctrl",
|
|
||||||
"gptbigcode",
|
|
||||||
"transo_xl",
|
|
||||||
"xlnet",
|
|
||||||
"cpm",
|
|
||||||
"jamba",
|
|
||||||
]
|
|
||||||
):
|
|
||||||
self.skipTest(reason="May fix in the future: need model-specific fixes")
|
|
||||||
|
|
||||||
set_model_tester_for_less_flaky_test(self)
|
|
||||||
|
|
||||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
|
||||||
set_config_for_less_flaky_test(config)
|
|
||||||
# batch_size=1 is ok, but batch_size>1 will cause non-identical output
|
|
||||||
|
|
||||||
config.use_cache = True
|
|
||||||
config.is_decoder = True
|
|
||||||
|
|
||||||
# test output equality of low versus high memory
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
|
||||||
set_model_for_less_flaky_test(model)
|
|
||||||
|
|
||||||
logits_processor_kwargs = self._get_logits_processor_kwargs(config=model.config)
|
|
||||||
|
|
||||||
low_output = model.generate(
|
|
||||||
**inputs_dict,
|
|
||||||
max_new_tokens=8,
|
|
||||||
num_beams=5,
|
|
||||||
early_stopping=True,
|
|
||||||
low_memory=True,
|
|
||||||
use_cache=True,
|
|
||||||
output_scores=True,
|
|
||||||
output_logits=True,
|
|
||||||
return_dict_in_generate=True,
|
|
||||||
**logits_processor_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
high_output = model.generate(
|
|
||||||
**inputs_dict,
|
|
||||||
max_new_tokens=8,
|
|
||||||
num_beams=5,
|
|
||||||
early_stopping=True,
|
|
||||||
low_memory=False,
|
|
||||||
use_cache=True,
|
|
||||||
output_scores=True,
|
|
||||||
output_logits=True,
|
|
||||||
return_dict_in_generate=True,
|
|
||||||
**logits_processor_kwargs,
|
|
||||||
)
|
|
||||||
# The two outputs must match and their shape must be as expected
|
|
||||||
self._check_similar_generate_outputs(low_output, high_output)
|
|
||||||
|
|
||||||
@parameterized.expand([("random",), ("same",)])
|
@parameterized.expand([("random",), ("same",)])
|
||||||
@pytest.mark.generate
|
@pytest.mark.generate
|
||||||
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
|
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
|
||||||
@@ -2964,19 +2900,6 @@ class GenerationIntegrationTests(unittest.TestCase):
|
|||||||
|
|
||||||
torch.testing.assert_close(transition_scores_sum, outputs.sequences_scores, rtol=1e-3, atol=1e-3)
|
torch.testing.assert_close(transition_scores_sum, outputs.sequences_scores, rtol=1e-3, atol=1e-3)
|
||||||
|
|
||||||
def test_beam_search_low_memory(self):
|
|
||||||
tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
|
|
||||||
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
|
|
||||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
|
||||||
model_inputs = tokenizer("I", return_tensors="pt")["input_ids"]
|
|
||||||
|
|
||||||
low_output = model.generate(model_inputs, max_new_tokens=40, num_beams=5, early_stopping=True, low_memory=True)
|
|
||||||
|
|
||||||
high_output = model.generate(
|
|
||||||
model_inputs, max_new_tokens=40, num_beams=5, early_stopping=True, low_memory=False
|
|
||||||
)
|
|
||||||
self.assertListEqual(low_output.tolist(), high_output.tolist())
|
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_green_red_watermark_generation(self):
|
def test_green_red_watermark_generation(self):
|
||||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
|
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
|
||||||
@@ -4311,6 +4234,42 @@ class GenerationIntegrationTests(unittest.TestCase):
|
|||||||
self.assertEqual(decoded_assisted, [expected_output])
|
self.assertEqual(decoded_assisted, [expected_output])
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
|
def test_beam_search_advanced_stopping_criteria(self):
|
||||||
|
"""
|
||||||
|
Tests that beam search works with a stopping criteria that is not max length or EOS token. Prior to the beam
|
||||||
|
search vectorization PR (#35802), beam search was not accepting other stopping criteria. Test inspired on
|
||||||
|
the original issue (#34843).
|
||||||
|
"""
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
||||||
|
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct").to(torch_device)
|
||||||
|
|
||||||
|
prompt = (
|
||||||
|
"Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. "
|
||||||
|
"How many clips did Natalia sell altogether in April and May?"
|
||||||
|
)
|
||||||
|
tokens = tokenizer(prompt, return_tensors="pt").to(torch_device)
|
||||||
|
generation_config = GenerationConfig(num_beams=3, do_sample=False, length_penalty=1.0, max_new_tokens=100)
|
||||||
|
|
||||||
|
# This particular prompt should result in a ":" being present in the answer
|
||||||
|
out = model.generate(**tokens, generation_config=generation_config, tokenizer=tokenizer)
|
||||||
|
output_text = tokenizer.decode(out[0], skip_special_tokens=True)
|
||||||
|
last_non_special_token_decoded = tokenizer.decode(out[out != tokenizer.pad_token_id][-1])
|
||||||
|
self.assertTrue(":" in output_text)
|
||||||
|
self.assertFalse(":" in output_text[-5:])
|
||||||
|
self.assertFalse(":" in last_non_special_token_decoded)
|
||||||
|
|
||||||
|
# Adding an advanced stopping criteria: text generation should stop when a ":" is generated.
|
||||||
|
# Note that:
|
||||||
|
# 1 - the text up to ":" doesn't have to be the same, it can belong to a different beam
|
||||||
|
# 2 - ":" may not be the last char, but it must be in the last non-special token
|
||||||
|
generation_config.stop_strings = ":"
|
||||||
|
out = model.generate(**tokens, generation_config=generation_config, tokenizer=tokenizer)
|
||||||
|
output_text = tokenizer.decode(out[0], skip_special_tokens=True)
|
||||||
|
last_non_special_token_decoded = tokenizer.decode(out[out != tokenizer.pad_token_id][-1])
|
||||||
|
self.assertTrue(":" in output_text)
|
||||||
|
self.assertTrue(":" in output_text[-5:])
|
||||||
|
self.assertTrue(":" in last_non_special_token_decoded)
|
||||||
|
|
||||||
def test_max_time(self):
|
def test_max_time(self):
|
||||||
tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
|
tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
|
||||||
model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2")
|
model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2")
|
||||||
|
|||||||
@@ -104,10 +104,6 @@ class Cohere2ModelTest(CohereModelTest, unittest.TestCase):
|
|||||||
def test_generate_continue_from_past_key_values(self):
|
def test_generate_continue_from_past_key_values(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@unittest.skip("Cohere2 has HybridCache and doesn't support low_memory generation")
|
|
||||||
def test_beam_search_low_memory(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@unittest.skip("Cohere2 has HybridCache and doesn't support contrastive generation")
|
@unittest.skip("Cohere2 has HybridCache and doesn't support contrastive generation")
|
||||||
def test_contrastive_generate(self):
|
def test_contrastive_generate(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -119,10 +119,6 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase):
|
|||||||
def test_generate_continue_from_past_key_values(self):
|
def test_generate_continue_from_past_key_values(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@unittest.skip("Gemma2 has HybridCache and doesn't support low_memory generation")
|
|
||||||
def test_beam_search_low_memory(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@unittest.skip("Gemma2 has HybridCache and doesn't support contrastive generation")
|
@unittest.skip("Gemma2 has HybridCache and doesn't support contrastive generation")
|
||||||
def test_contrastive_generate(self):
|
def test_contrastive_generate(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -332,12 +332,6 @@ class Qwen2_5_VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
|
|||||||
def test_model_is_small(self):
|
def test_model_is_small(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@unittest.skip(
|
|
||||||
reason="Qwen2.5-VL can't do low-memory generation because position IDs have extra dimension and split function doesn't work for that"
|
|
||||||
)
|
|
||||||
def test_beam_search_low_memory(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@unittest.skip(
|
@unittest.skip(
|
||||||
reason="VLMs can't generate from inputs embeds and pixels. This can be tested as part of bacbone LM, no need to run the tes for VLMs"
|
reason="VLMs can't generate from inputs embeds and pixels. This can be tested as part of bacbone LM, no need to run the tes for VLMs"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -344,12 +344,6 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
|
|||||||
def test_model_is_small(self):
|
def test_model_is_small(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@unittest.skip(
|
|
||||||
reason="Qwen2-VL can't do low-memory generation because position IDs have extra dimension and split function doesn't work for that"
|
|
||||||
)
|
|
||||||
def test_beam_search_low_memory(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@unittest.skip(
|
@unittest.skip(
|
||||||
reason="VLMs can't generate from inputs embeds and pixels. This can be tested as part of bacbone LM, no need to run the test for VLMs"
|
reason="VLMs can't generate from inputs embeds and pixels. This can be tested as part of bacbone LM, no need to run the test for VLMs"
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user