[generate] vectorized beam search (#35802)

This commit is contained in:
Joao Gante
2025-03-18 18:39:36 +00:00
committed by GitHub
parent 12f2ebef63
commit 179d02ffb8
8 changed files with 426 additions and 271 deletions

View File

@@ -2118,7 +2118,7 @@ class TFGenerationMixin:
a greedy approach, otherwise does multinomial sampling without replacement. a greedy approach, otherwise does multinomial sampling without replacement.
Parameters: Parameters:
input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): input_ids (`tf.Tensor` of shape `(batch_size, num_beams, sequence_length)`):
The sequence used as a prompt for the generation. The sequence used as a prompt for the generation.
do_sample (`bool`, *optional*, defaults to `False`): do_sample (`bool`, *optional*, defaults to `False`):
Whether or not to use sampling ; use greedy decoding otherwise. Whether or not to use sampling ; use greedy decoding otherwise.

View File

@@ -2322,29 +2322,16 @@ class GenerationMixin:
) )
elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH): elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
# 11. prepare beam search scorer # 11. interleave input_ids with `num_beams` additional sequences per batch
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
num_beams=generation_config.num_beams,
device=inputs_tensor.device,
length_penalty=generation_config.length_penalty,
do_early_stopping=generation_config.early_stopping,
num_beam_hyps_to_keep=generation_config.num_return_sequences,
max_length=generation_config.max_length,
)
# 12. interleave input_ids with `num_beams` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids=input_ids, input_ids=input_ids,
expand_size=generation_config.num_beams, expand_size=generation_config.num_beams,
is_encoder_decoder=self.config.is_encoder_decoder, is_encoder_decoder=self.config.is_encoder_decoder,
**model_kwargs, **model_kwargs,
) )
# 12. run beam sample
# 13. run beam sample
result = self._beam_search( result = self._beam_search(
input_ids, input_ids,
beam_scorer,
logits_processor=prepared_logits_processor, logits_processor=prepared_logits_processor,
stopping_criteria=prepared_stopping_criteria, stopping_criteria=prepared_stopping_criteria,
generation_config=generation_config, generation_config=generation_config,
@@ -3396,6 +3383,7 @@ class GenerationMixin:
else: else:
return input_ids return input_ids
# Auxiliary functions for beam search
def _temporary_reorder_cache(self, past_key_values, beam_idx): def _temporary_reorder_cache(self, past_key_values, beam_idx):
""" """
Temporary function to handle the different types of cache reordering processes while we roll out `Cache`. Temporary function to handle the different types of cache reordering processes while we roll out `Cache`.
@@ -3422,10 +3410,208 @@ class GenerationMixin:
past_key_values.reorder_cache(beam_idx) past_key_values.reorder_cache(beam_idx)
return past_key_values return past_key_values
@staticmethod
def _flatten_beam_dim(tensor: torch.Tensor) -> torch.Tensor:
"""[batch_size, num_beams, ...] -> [batch_size * num_beams, ...]"""
shape = list(tensor.shape)
return torch.reshape(tensor, [shape[0] * shape[1]] + shape[2:])
@staticmethod
def _unflatten_beam_dim(tensor: torch.Tensor, batch_size: int, num_beams: int) -> torch.Tensor:
"""[batch_size * num_beams, ...] -> [batch_size, num_beams, ...]"""
shape = list(tensor.shape)
return torch.reshape(tensor, [batch_size, num_beams] + shape[1:])
@staticmethod
def _gather_beams(tensor: torch.Tensor, beam_indices: torch.Tensor) -> torch.Tensor:
"""
Gathers the beam slices indexed by beam_indices into new beam array.
Args:
tensor (`torch.Tensor`): A tensor containing data to be gathered. The tensor is a 2D or a 3D tensor
with the two first dimensions depicting the batch and the beam dimensions.
beam_indices (`torch.Tensor` of shape `(batch_size, num_beams_to_select)`): The indices of the beams to
select .
Returns:
A tensor with the selected beams
"""
# `take_along_dim` requires its indices arg to have the same number of dims as `input`
while len(beam_indices.shape) < len(tensor.shape):
beam_indices = beam_indices.unsqueeze(-1)
gathered_tensor = torch.take_along_dim(input=tensor, indices=beam_indices, dim=1)
return gathered_tensor
@staticmethod
def _beam_search_has_unfinished_sequences(
running_beam_scores: torch.Tensor,
beam_scores: torch.Tensor,
is_sent_finished: torch.Tensor,
next_token_hits_stopping_criteria: torch.Tensor,
cur_len: int,
max_length: int,
decoder_prompt_len: int,
early_stopping: Union[bool, str],
length_penalty: float,
):
"""
Beam Search stopping condition -- halts the generation loop if any of these conditions becomes False
"""
# a. Can the open beams improve the top completed scores?
# early_stopping == False -> apply heuristic = always get the best score from
# `cur_len - decoder_prompt_len`. See the discussion below for more details.
# https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
# early_stopping == "never" -> compute the best score from `max_length` or `cur_len`, depending on the
# sign of `length_penalty`. Positive `length_penalty` favors longer sequences, thus we use
# `max_length` there.
if early_stopping == "never" and length_penalty > 0.0:
best_hypothetical_length = max_length - decoder_prompt_len
else:
best_hypothetical_length = cur_len - decoder_prompt_len
best_possible_running_score = running_beam_scores[:, :1] / (best_hypothetical_length**length_penalty)
worst_finished_score = torch.where(is_sent_finished, torch.min(beam_scores, dim=1, keepdim=True)[0], -1.0e9)
improvement_possible = torch.any(best_possible_running_score > worst_finished_score)
# b. Is there still a beam without fully completed sequences? This is only relevant if early_stopping is
# enabled, where we want to finish as soon as all beams have a completed sequence.
exists_open_beam = ~(torch.all(is_sent_finished) & (early_stopping is True))
# c. Have we hit a stopping criteria with all running sequences and have no way to continue? e.g. we have
# reached `max_length``
valid_continuations = ~torch.all(next_token_hits_stopping_criteria)
return improvement_possible & exists_open_beam & valid_continuations
def _get_top_k_continuations(
self,
accumulated_log_probs: torch.Tensor,
running_sequences: torch.Tensor,
running_beam_indices: torch.Tensor,
cur_len: int,
decoder_prompt_len: int,
do_sample: bool,
beams_to_keep: int,
num_beams: int,
vocab_size: int,
batch_size: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Get top-K continuations given the accumulated log probs on the next token.
A few notes to understand what's going on:
1. Each item in batch has `num_beams` * `vocab_size` candidate continuations. For each item, get the
top K [K = (number of EOS tokens + 1) * `num_beams`] candidates with the highest accumulated
log-probabilities, or sample them without replacement using the accumulated scores
2. We gather the top K (as opposed to `num_beams`, or any number lower than K) here so that we have at
least `num_beams` sequences remaining to continue the live beam search.
3. Note that other stopping criteria might result in impossible to continue beams, i.e. all continuations
selected in this step hit the stopping criteria.
"""
# TODO (joao): This function should take an optional beam scorer function, to manipulate the scores after
# token selection. The function should be an argument exposed, so that custom scoring functions can be
# defined.
# Gather the top K scores from _all_ beams.
if do_sample:
topk_indices = torch.multinomial(
nn.functional.softmax(accumulated_log_probs, dim=-1), num_samples=beams_to_keep
)
topk_log_probs = torch.gather(input=accumulated_log_probs, dim=1, index=topk_indices)
else:
topk_log_probs, topk_indices = torch.topk(accumulated_log_probs, k=beams_to_keep)
# Gather K top beams, recover the beam index by floor division and token id by modulo division
topk_current_beam_indices = topk_indices // vocab_size
topk_running_beam_indices = self._gather_beams(running_beam_indices, topk_current_beam_indices)
topk_running_sequences = self._gather_beams(running_sequences, topk_current_beam_indices)
topk_ids = topk_indices % vocab_size
# Update sequences for the K top-k new sequences.
topk_running_sequences[:, :, cur_len] = topk_ids
# we want to store the beam indices with batch information -> real beam index = beam index % num beams
batch_offset = torch.arange(batch_size, device=topk_ids.device).view(-1, 1) * num_beams
batch_modified_indices = topk_current_beam_indices + batch_offset
topk_running_beam_indices[:, :, cur_len - decoder_prompt_len] = batch_modified_indices
return topk_log_probs, topk_running_sequences, topk_running_beam_indices
def _get_running_beams_for_next_iteration(
self,
topk_log_probs: torch.Tensor,
topk_running_sequences: torch.Tensor,
topk_running_beam_indices: torch.Tensor,
next_token_hits_stopping_criteria: torch.Tensor,
num_beams: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Given the top-K continuations, their scores, and whether they hit a stopping criteria, select the
best non-finished beams to continue beam search in the next iteration.
"""
# To prevent these just finished sequences from being used in subsequent iterations, set their log probs
# to a very large negative value
topk_running_log_probs = topk_log_probs + next_token_hits_stopping_criteria.to(torch.float32) * -1.0e9
next_topk_indices = torch.topk(topk_running_log_probs, k=num_beams)[1]
running_sequences = self._gather_beams(topk_running_sequences, next_topk_indices)
running_beam_scores = self._gather_beams(topk_running_log_probs, next_topk_indices)
running_beam_indices = self._gather_beams(topk_running_beam_indices, next_topk_indices)
return running_sequences, running_beam_scores, running_beam_indices
def _update_finished_beams(
self,
sequences: torch.Tensor,
topk_running_sequences: torch.Tensor,
beam_scores: torch.Tensor,
topk_log_probs: torch.Tensor,
beam_indices: torch.Tensor,
topk_running_beam_indices: torch.Tensor,
is_sent_finished: torch.Tensor,
next_token_hits_stopping_criteria: torch.Tensor,
top_num_beam_mask: torch.Tensor,
num_beams: int,
cur_len: int,
decoder_prompt_len: int,
length_penalty: float,
early_stopping: Union[bool, str],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Updates the finished beams if (and only if) there are new completed sequences that have a higher score than
the current finished sequences.
"""
# Only the top `num_beam` sequences can be considered for the final returned sequences. Remember: the
# remaining sequences only exist as a backup to ensure that we have at least `num_beams` sequences to
# continue.
did_top_num_beams_just_finished = next_token_hits_stopping_criteria & top_num_beam_mask[None, :]
# Further process topk logits for the finished beams
# - add length penalty
topk_log_probs = topk_log_probs / ((cur_len + 1 - decoder_prompt_len) ** length_penalty)
# - make sure no scores can be added anymore if beam is full and early stopping is on
beams_in_batch_are_full = torch.all(is_sent_finished, axis=-1, keepdims=True) & (early_stopping is True)
topk_log_probs += beams_in_batch_are_full.to(torch.float32) * -1.0e9
# - make sure still running sequences cannot be chosen as finalized beam
topk_log_probs += (~did_top_num_beams_just_finished) * -1.0e9
# Get finalized `num_beam` sequences for the next generation step -- combine the previous finalized
# data with the new finalized sequences (if any, non-finalized sequences have a very large negative score
# in this step), and keep the best `num_beams` sequences.
merged_sequences = torch.cat((sequences, topk_running_sequences), dim=1)
merged_scores = torch.cat((beam_scores, topk_log_probs), dim=1)
merged_beam_indices = torch.cat((beam_indices, topk_running_beam_indices), dim=1)
merged_is_sent_finished = torch.cat((is_sent_finished, did_top_num_beams_just_finished), dim=1)
topk_merged_indices = torch.topk(merged_scores, k=num_beams)[1]
sequences = self._gather_beams(merged_sequences, topk_merged_indices)
beam_scores = self._gather_beams(merged_scores, topk_merged_indices)
beam_indices = self._gather_beams(merged_beam_indices, topk_merged_indices)
is_sent_finished = self._gather_beams(merged_is_sent_finished, topk_merged_indices)
return sequences, beam_scores, beam_indices, is_sent_finished
# end of auxiliary functions for beam search
def _beam_search( def _beam_search(
self, self,
input_ids: torch.LongTensor, input_ids: torch.LongTensor,
beam_scorer: BeamScorer,
logits_processor: LogitsProcessorList, logits_processor: LogitsProcessorList,
stopping_criteria: StoppingCriteriaList, stopping_criteria: StoppingCriteriaList,
generation_config: GenerationConfig, generation_config: GenerationConfig,
@@ -3436,12 +3622,15 @@ class GenerationMixin:
Generates sequences of token ids for models with a language modeling head using **beam search decoding** and Generates sequences of token ids for models with a language modeling head using **beam search decoding** and
can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
If it's the first time you're diving into Beam Search, we recommend you read the following blog post:
https://huggingface.co/blog/how-to-generate (especially the beam search section).
You can recompute the sequence scores from the individual scores using the `compute_transition_scores` function
(https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationMixin.compute_transition_scores)
Parameters: Parameters:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): input_ids (`torch.LongTensor` of shape `(batch_size*num_beams, sequence_length)`):
The sequence used as a prompt for the generation. The sequence used as a prompt for the generation.
beam_scorer (`BeamScorer`):
An derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
sorted during generation. For more information, the documentation of [`BeamScorer`] should be read.
logits_processor (`LogitsProcessorList`): logits_processor (`LogitsProcessorList`):
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
used to modify the prediction scores of the language modeling head applied at each generation step. used to modify the prediction scores of the language modeling head applied at each generation step.
@@ -3464,7 +3653,8 @@ class GenerationMixin:
`return_dict_in_generate=True` or a [`~generation.GenerateBeamEncoderDecoderOutput`] if `return_dict_in_generate=True` or a [`~generation.GenerateBeamEncoderDecoderOutput`] if
`model.config.is_encoder_decoder=True`. `model.config.is_encoder_decoder=True`.
""" """
# init values
# 1. init beam_search values
pad_token_id = generation_config._pad_token_tensor pad_token_id = generation_config._pad_token_tensor
eos_token_id = generation_config._eos_token_tensor eos_token_id = generation_config._eos_token_tensor
output_attentions = generation_config.output_attentions output_attentions = generation_config.output_attentions
@@ -3472,26 +3662,51 @@ class GenerationMixin:
output_scores = generation_config.output_scores output_scores = generation_config.output_scores
output_logits = generation_config.output_logits output_logits = generation_config.output_logits
return_dict_in_generate = generation_config.return_dict_in_generate return_dict_in_generate = generation_config.return_dict_in_generate
sequential = generation_config.low_memory
do_sample = generation_config.do_sample do_sample = generation_config.do_sample
early_stopping = generation_config.early_stopping
length_penalty = generation_config.length_penalty
max_length = generation_config.max_length
num_beams = generation_config.num_beams
num_return_sequences = generation_config.num_return_sequences
batch_size = len(beam_scorer._beam_hyps) batch_size_unflattened, cur_len = input_ids.shape
num_beams = beam_scorer.num_beams batch_size = batch_size_unflattened // num_beams
# TODO (joao): standardize special cases
if self.__class__.__name__ == "MoshiDepthDecoder":
vocab_size = self.config.audio_vocab_size
elif self.__class__.__name__ == "ImageGPTForCausalImageModeling":
vocab_size = self.get_output_embeddings().out_features
else:
vocab_size = self.config.get_text_config().vocab_size
decoder_prompt_len = cur_len
this_peer_finished = False
# At each beam search step, we want to keep top K [K = (number of EOS tokens + 1) * `num_beams`] candidates
# with the highest log-probabilities, or sample K continuations without replacement. We gather the top K
# (as opposed to `num_beams`, or any number lower than K) so that we have at least `num_beams` sequences
# non-finished to continue the live beam search, in case the top `num_beams` all select an EOS token.
n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0
beams_to_keep = max(2, 1 + n_eos_tokens) * num_beams
top_num_beam_mask = torch.cat(
(torch.ones((num_beams), dtype=torch.bool), torch.zeros((beams_to_keep - num_beams), dtype=torch.bool)),
dim=0,
).to(input_ids.device)
batch_beam_size, cur_len = input_ids.shape
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
if num_beams * batch_size != batch_beam_size: # (joao) feature lost in the refactor. Probably won't implement, hurts readbility with minimal gains (there
# are newer low-memory alternatives like the offloaded cache)
sequential = generation_config.low_memory
if sequential:
raise ValueError( raise ValueError(
f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." "`low_memory=True` is not supported after the beam search refactor. Please check the discussion in "
"#35802 *after the PR got merged*, and add a comment there if your questions are not yet answered."
) )
# init attention / hidden states / scores tuples # 2. init output tuples
scores = () if (return_dict_in_generate and output_scores) else None all_scores = () if (return_dict_in_generate and output_scores) else None
raw_logits = () if (return_dict_in_generate and output_logits) else None raw_logits = () if (return_dict_in_generate and output_logits) else None
beam_indices = ( beam_indices = () if (return_dict_in_generate and output_logits) else None
tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
)
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
@@ -3503,184 +3718,195 @@ class GenerationMixin:
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
) )
# 3. init running tensors and static-shaped placeholders
# per batch, beam-item holding current token in loop and completed sequences
output_fill_value = pad_token_id or eos_token_id[0] if eos_token_id is not None else -1
running_sequences = torch.full(
(batch_size, num_beams, max_length),
fill_value=output_fill_value,
dtype=torch.int64,
device=input_ids.device,
)
running_sequences[:, :, :cur_len] = self._unflatten_beam_dim(input_ids, batch_size, num_beams)
sequences = running_sequences.clone().detach()
# per batch, beam-item score, logprobs
# initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens
# of the first beam are considered to avoid sampling the exact same tokens across all beams. # of the first beam are considered to avoid sampling the exact same tokens across all beams.
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) running_beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
beam_scores[:, 1:] = -1e9 running_beam_scores[:, 1:] = -1e9
beam_scores = beam_scores.view((batch_size * num_beams,)) beam_scores = torch.full((batch_size, num_beams), fill_value=-1e9, dtype=torch.float, device=input_ids.device)
this_peer_finished = False # per batch, beam-item state bit indicating if sentence has finished.
is_sent_finished = torch.zeros((batch_size, num_beams), dtype=torch.bool, device=input_ids.device)
decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder # per batch, beam-item state bit indicating if there are valid continuations.
next_token_hits_stopping_criteria = torch.zeros(
(batch_size, num_beams), dtype=torch.bool, device=input_ids.device
)
# per batch selected beam indices
running_beam_indices = torch.full(
(batch_size, num_beams, max_length - cur_len), fill_value=-1, dtype=torch.int32, device=input_ids.device
)
beam_indices = running_beam_indices.clone().detach()
# 4. run the generation loop
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) # a. Forward current tokens, obtain the logits
flat_running_sequences = self._flatten_beam_dim(running_sequences[:, :, :cur_len])
model_inputs = self.prepare_inputs_for_generation(flat_running_sequences, **model_kwargs)
# prepare variable output controls (note: some models won't accept all output controls) # prepare variable output controls (note: some models won't accept all output controls)
model_inputs.update({"output_attentions": output_attentions} if output_attentions else {}) model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {}) model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
# if sequential is True, split the input to batches of batch_size and run sequentially model_outputs = self(**model_inputs, return_dict=True)
if sequential:
if any(
model_name in self.__class__.__name__.lower()
for model_name in [
"fsmt",
"reformer",
"ctrl",
"gpt_bigcode",
"transo_xl",
"xlnet",
"cpm",
"jamba",
]
):
raise RuntimeError(
f"Currently generation for {self.__class__.__name__} is not supported "
f"for `low_memory beam_search`. Please open an issue on GitHub if you need this feature."
)
inputs_per_sub_batches = _split_model_inputs(
model_inputs,
split_size=batch_size,
full_batch_size=batch_beam_size,
config=self.config.get_text_config(),
)
outputs_per_sub_batch = [
self(**inputs_per_sub_batch, return_dict=True) for inputs_per_sub_batch in inputs_per_sub_batches
]
outputs = stack_model_outputs(outputs_per_sub_batch, self.config.get_text_config())
else: # Unchanged original behavior
outputs = self(**model_inputs, return_dict=True)
# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
model_kwargs = self._update_model_kwargs_for_generation( model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_outputs,
model_kwargs, model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder, is_encoder_decoder=self.config.is_encoder_decoder,
) )
if synced_gpus and this_peer_finished: if synced_gpus and this_peer_finished:
cur_len = cur_len + 1
continue continue
# Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration logits = model_outputs.logits[:, -1, :].clone().float() # Clone is needed to avoid keeping a hanging ref
# (the clone itself is always small) logits = logits.to(input_ids.device)
# .float() is needed to retain precision for later logits manipulations
next_token_logits = outputs.logits[:, -1, :].clone().float()
next_token_logits = next_token_logits.to(input_ids.device)
next_token_scores = nn.functional.log_softmax(
next_token_logits, dim=-1
) # (batch_size * num_beams, vocab_size)
next_token_scores_processed = logits_processor(input_ids, next_token_scores) # b. Compute log probs -- get log probabilities from logits, process logits with processors (*e.g.*
next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as( # `temperature`, ...), and add new logprobs to existing running logprobs scores.
next_token_scores_processed log_probs = nn.functional.log_softmax(logits, dim=-1)
) log_probs = logits_processor(flat_running_sequences, log_probs)
# Store scores, attentions and hidden_states when required # Store logits, attentions and hidden_states when required
if return_dict_in_generate: if return_dict_in_generate:
if output_scores:
scores += (next_token_scores_processed,)
if output_logits: if output_logits:
raw_logits += (next_token_logits,) raw_logits += (logits.clone(),)
if return_dict_in_generate and output_scores:
all_scores += (log_probs.clone(),)
if output_attentions: if output_attentions:
decoder_attentions += ( decoder_attentions += (
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) (model_outputs.decoder_attentions,)
if self.config.is_encoder_decoder
else (model_outputs.attentions,)
) )
if self.config.is_encoder_decoder: if self.config.is_encoder_decoder:
cross_attentions += (outputs.cross_attentions,) cross_attentions += (model_outputs.cross_attentions,)
if output_hidden_states: if output_hidden_states:
decoder_hidden_states += ( decoder_hidden_states += (
(outputs.decoder_hidden_states,) (model_outputs.decoder_hidden_states,)
if self.config.is_encoder_decoder if self.config.is_encoder_decoder
else (outputs.hidden_states,) else (model_outputs.hidden_states,)
) )
# reshape for beam search # This is needed to properly delete logits which may be very large for first iteration
vocab_size = next_token_scores.shape[-1] # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) del model_outputs
# Beam token selection: pick 1 + eos_token_id.shape[0] next tokens for each beam so we have at least 1 log_probs = self._unflatten_beam_dim(log_probs, batch_size, num_beams)
# non eos token per beam. log_probs = log_probs + running_beam_scores[:, :, None]
n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0 log_probs = torch.reshape(log_probs, (batch_size, num_beams * vocab_size))
n_tokens_to_keep = max(2, 1 + n_eos_tokens) * num_beams
if do_sample:
probs = nn.functional.softmax(next_token_scores, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=n_tokens_to_keep)
next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
next_tokens = torch.gather(next_tokens, -1, _indices)
else:
next_token_scores, next_tokens = torch.topk(
next_token_scores, n_tokens_to_keep, dim=1, largest=True, sorted=True
)
next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor") # c. Retrieve top-K continuations, i.e. select the next token (greedy or sampling) and then keep the best
next_tokens = next_tokens % vocab_size # continuations among all beams based on the accumulated scores.
topk_log_probs, topk_running_sequences, topk_running_beam_indices = self._get_top_k_continuations(
# stateless accumulated_log_probs=log_probs,
beam_outputs = beam_scorer.process( running_sequences=running_sequences,
input_ids, running_beam_indices=running_beam_indices,
next_token_scores, cur_len=cur_len,
next_tokens,
next_indices,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
beam_indices=beam_indices,
decoder_prompt_len=decoder_prompt_len, decoder_prompt_len=decoder_prompt_len,
do_sample=do_sample,
beams_to_keep=beams_to_keep,
num_beams=num_beams,
vocab_size=vocab_size,
batch_size=batch_size,
) )
beam_scores = beam_outputs["next_beam_scores"] # d. Check which running sequences have finished
beam_next_tokens = beam_outputs["next_beam_tokens"] next_token_hits_stopping_criteria = stopping_criteria(
beam_idx = beam_outputs["next_beam_indices"] self._flatten_beam_dim(topk_running_sequences[:, :, : cur_len + 1]), # remove unfilled token indexes
all_scores,
)
next_token_hits_stopping_criteria = self._unflatten_beam_dim(
next_token_hits_stopping_criteria, batch_size, beams_to_keep
)
input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) # e. Get the non-finished running `num_beams` sequences for the next generation step
running_sequences, running_beam_scores, running_beam_indices = self._get_running_beams_for_next_iteration(
topk_log_probs=topk_log_probs,
topk_running_sequences=topk_running_sequences,
topk_running_beam_indices=topk_running_beam_indices,
next_token_hits_stopping_criteria=next_token_hits_stopping_criteria,
num_beams=num_beams,
)
# This is needed to properly delete outputs.logits which may be very large for first iteration # f. Update the completed beams if a new high score in a finished sequence is found
# Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration sequences, beam_scores, beam_indices, is_sent_finished = self._update_finished_beams(
# IMPORTANT: Note that this should appear BEFORE the call to _reorder_cache() to save the maximum memory sequences=sequences,
# (that way the memory peak does not include outputs.logits) topk_running_sequences=topk_running_sequences,
del outputs beam_scores=beam_scores,
topk_log_probs=topk_log_probs,
beam_indices=beam_indices,
topk_running_beam_indices=topk_running_beam_indices,
is_sent_finished=is_sent_finished,
next_token_hits_stopping_criteria=next_token_hits_stopping_criteria,
top_num_beam_mask=top_num_beam_mask,
num_beams=num_beams,
cur_len=cur_len,
decoder_prompt_len=decoder_prompt_len,
length_penalty=length_penalty,
early_stopping=early_stopping,
)
# g. Prepare remaining data for the next iteration, including computing the stopping condition for
# beam search as a whole (as opposed to individual beams, i.e. `stopping_criteria`)
# pluck the cache from the beam indices that will be used in the next iteration
if model_kwargs.get("past_key_values", None) is not None: if model_kwargs.get("past_key_values", None) is not None:
model_kwargs["past_key_values"] = self._temporary_reorder_cache( model_kwargs["past_key_values"] = self._temporary_reorder_cache(
model_kwargs["past_key_values"], beam_idx past_key_values=model_kwargs["past_key_values"],
beam_idx=self._flatten_beam_dim(running_beam_indices[..., cur_len - decoder_prompt_len]),
) )
if return_dict_in_generate and output_scores:
beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))
# increase cur_len
cur_len = cur_len + 1 cur_len = cur_len + 1
this_peer_finished = not self._beam_search_has_unfinished_sequences(
running_beam_scores,
beam_scores,
is_sent_finished,
next_token_hits_stopping_criteria,
cur_len,
max_length,
decoder_prompt_len,
early_stopping,
length_penalty,
)
if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)): # 5. prepare outputs
this_peer_finished = True # Take best beams for each batch (the score is sorted in descending order)
sequences = self._flatten_beam_dim(sequences[:, :num_return_sequences, :])
beam_scores = self._flatten_beam_dim(beam_scores[:, :num_return_sequences])
beam_indices = self._flatten_beam_dim(beam_indices[:, :num_return_sequences, :])
sequence_outputs = beam_scorer.finalize( # Crop the static-shaped tensors to the actual size
input_ids, sequences = sequences[:, :cur_len]
beam_scores, beam_indices = beam_indices[:, : cur_len - decoder_prompt_len]
next_tokens,
next_indices,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
max_length=stopping_criteria.max_length,
beam_indices=beam_indices,
decoder_prompt_len=decoder_prompt_len,
)
if return_dict_in_generate: if return_dict_in_generate:
if not output_scores: if not output_scores:
sequence_outputs["sequence_scores"] = None beam_scores = None
if self.config.is_encoder_decoder: if self.config.is_encoder_decoder:
return GenerateBeamEncoderDecoderOutput( return GenerateBeamEncoderDecoderOutput(
sequences=sequence_outputs["sequences"], sequences=sequences,
sequences_scores=sequence_outputs["sequence_scores"], sequences_scores=beam_scores,
scores=scores, scores=all_scores,
logits=raw_logits, logits=raw_logits,
beam_indices=sequence_outputs["beam_indices"], beam_indices=beam_indices,
encoder_attentions=encoder_attentions, encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
decoder_attentions=decoder_attentions, decoder_attentions=decoder_attentions,
@@ -3690,17 +3916,17 @@ class GenerationMixin:
) )
else: else:
return GenerateBeamDecoderOnlyOutput( return GenerateBeamDecoderOnlyOutput(
sequences=sequence_outputs["sequences"], sequences=sequences,
sequences_scores=sequence_outputs["sequence_scores"], sequences_scores=beam_scores,
scores=scores, scores=all_scores,
logits=raw_logits, logits=raw_logits,
beam_indices=sequence_outputs["beam_indices"], beam_indices=beam_indices,
attentions=decoder_attentions, attentions=decoder_attentions,
hidden_states=decoder_hidden_states, hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"), past_key_values=model_kwargs.get("past_key_values"),
) )
else: else:
return sequence_outputs["sequences"] return sequences
def _group_beam_search( def _group_beam_search(
self, self,
@@ -3717,7 +3943,7 @@ class GenerationMixin:
decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
Parameters: Parameters:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): input_ids (`torch.LongTensor` of shape `(batch_size*num_beams, sequence_length)`):
The sequence used as a prompt for the generation. The sequence used as a prompt for the generation.
beam_scorer (`BeamScorer`): beam_scorer (`BeamScorer`):
An derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and An derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
@@ -4008,7 +4234,7 @@ class GenerationMixin:
decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
Parameters: Parameters:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): input_ids (`torch.LongTensor` of shape `(batch_size*num_beams, sequence_length)`):
The sequence used as a prompt for the generation. The sequence used as a prompt for the generation.
constrained_beam_scorer (`ConstrainedBeamSearchScorer`): constrained_beam_scorer (`ConstrainedBeamSearchScorer`):
A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and

View File

@@ -22,7 +22,7 @@ import torch
from torch import nn from torch import nn
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...generation import BeamSearchScorer, GenerationConfig, LogitsProcessorList, StoppingCriteriaList from ...generation import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
from ...modeling_outputs import ModelOutput from ...modeling_outputs import ModelOutput
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...utils import add_start_docstrings_to_model_forward, logging, replace_return_docstrings from ...utils import add_start_docstrings_to_model_forward, logging, replace_return_docstrings
@@ -1563,18 +1563,8 @@ class RagTokenForGeneration(RagPreTrainedModel):
elif generation_config.num_beams > 1: elif generation_config.num_beams > 1:
if generation_config.num_return_sequences > generation_config.num_beams: if generation_config.num_return_sequences > generation_config.num_beams:
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
num_beams=generation_config.num_beams,
device=self.device,
length_penalty=generation_config.length_penalty,
do_early_stopping=generation_config.early_stopping,
num_beam_hyps_to_keep=generation_config.num_return_sequences,
max_length=generation_config.max_length,
)
return self._beam_search( return self._beam_search(
input_ids, input_ids,
beam_scorer,
logits_processor=pre_processor, logits_processor=pre_processor,
stopping_criteria=prepared_stopping_criteria, stopping_criteria=prepared_stopping_criteria,
generation_config=generation_config, generation_config=generation_config,

View File

@@ -1099,70 +1099,6 @@ class GenerationTesterMixin:
) )
self.assertListEqual(low_output.tolist(), high_output.tolist()) self.assertListEqual(low_output.tolist(), high_output.tolist())
@pytest.mark.generate
def test_beam_search_low_memory(self):
# Check that choosing 'low_memory' does not change the model output
for model_class in self.all_generative_model_classes:
if model_class._is_stateful:
self.skipTest(reason="May fix in the future: need custom cache handling")
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
self.skipTest(reason="Won't fix: old model with different cache format")
if any(
model_name in model_class.__name__.lower()
for model_name in [
"ctrl",
"gptbigcode",
"transo_xl",
"xlnet",
"cpm",
"jamba",
]
):
self.skipTest(reason="May fix in the future: need model-specific fixes")
set_model_tester_for_less_flaky_test(self)
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
set_config_for_less_flaky_test(config)
# batch_size=1 is ok, but batch_size>1 will cause non-identical output
config.use_cache = True
config.is_decoder = True
# test output equality of low versus high memory
model = model_class(config).to(torch_device).eval()
set_model_for_less_flaky_test(model)
logits_processor_kwargs = self._get_logits_processor_kwargs(config=model.config)
low_output = model.generate(
**inputs_dict,
max_new_tokens=8,
num_beams=5,
early_stopping=True,
low_memory=True,
use_cache=True,
output_scores=True,
output_logits=True,
return_dict_in_generate=True,
**logits_processor_kwargs,
)
high_output = model.generate(
**inputs_dict,
max_new_tokens=8,
num_beams=5,
early_stopping=True,
low_memory=False,
use_cache=True,
output_scores=True,
output_logits=True,
return_dict_in_generate=True,
**logits_processor_kwargs,
)
# The two outputs must match and their shape must be as expected
self._check_similar_generate_outputs(low_output, high_output)
@parameterized.expand([("random",), ("same",)]) @parameterized.expand([("random",), ("same",)])
@pytest.mark.generate @pytest.mark.generate
def test_assisted_decoding_matches_greedy_search(self, assistant_type): def test_assisted_decoding_matches_greedy_search(self, assistant_type):
@@ -2964,19 +2900,6 @@ class GenerationIntegrationTests(unittest.TestCase):
torch.testing.assert_close(transition_scores_sum, outputs.sequences_scores, rtol=1e-3, atol=1e-3) torch.testing.assert_close(transition_scores_sum, outputs.sequences_scores, rtol=1e-3, atol=1e-3)
def test_beam_search_low_memory(self):
tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
tokenizer.pad_token_id = tokenizer.eos_token_id
model_inputs = tokenizer("I", return_tensors="pt")["input_ids"]
low_output = model.generate(model_inputs, max_new_tokens=40, num_beams=5, early_stopping=True, low_memory=True)
high_output = model.generate(
model_inputs, max_new_tokens=40, num_beams=5, early_stopping=True, low_memory=False
)
self.assertListEqual(low_output.tolist(), high_output.tolist())
@slow @slow
def test_green_red_watermark_generation(self): def test_green_red_watermark_generation(self):
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
@@ -4311,6 +4234,42 @@ class GenerationIntegrationTests(unittest.TestCase):
self.assertEqual(decoded_assisted, [expected_output]) self.assertEqual(decoded_assisted, [expected_output])
@slow @slow
def test_beam_search_advanced_stopping_criteria(self):
"""
Tests that beam search works with a stopping criteria that is not max length or EOS token. Prior to the beam
search vectorization PR (#35802), beam search was not accepting other stopping criteria. Test inspired on
the original issue (#34843).
"""
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct").to(torch_device)
prompt = (
"Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. "
"How many clips did Natalia sell altogether in April and May?"
)
tokens = tokenizer(prompt, return_tensors="pt").to(torch_device)
generation_config = GenerationConfig(num_beams=3, do_sample=False, length_penalty=1.0, max_new_tokens=100)
# This particular prompt should result in a ":" being present in the answer
out = model.generate(**tokens, generation_config=generation_config, tokenizer=tokenizer)
output_text = tokenizer.decode(out[0], skip_special_tokens=True)
last_non_special_token_decoded = tokenizer.decode(out[out != tokenizer.pad_token_id][-1])
self.assertTrue(":" in output_text)
self.assertFalse(":" in output_text[-5:])
self.assertFalse(":" in last_non_special_token_decoded)
# Adding an advanced stopping criteria: text generation should stop when a ":" is generated.
# Note that:
# 1 - the text up to ":" doesn't have to be the same, it can belong to a different beam
# 2 - ":" may not be the last char, but it must be in the last non-special token
generation_config.stop_strings = ":"
out = model.generate(**tokens, generation_config=generation_config, tokenizer=tokenizer)
output_text = tokenizer.decode(out[0], skip_special_tokens=True)
last_non_special_token_decoded = tokenizer.decode(out[out != tokenizer.pad_token_id][-1])
self.assertTrue(":" in output_text)
self.assertTrue(":" in output_text[-5:])
self.assertTrue(":" in last_non_special_token_decoded)
def test_max_time(self): def test_max_time(self):
tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2") tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2") model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2")

View File

@@ -104,10 +104,6 @@ class Cohere2ModelTest(CohereModelTest, unittest.TestCase):
def test_generate_continue_from_past_key_values(self): def test_generate_continue_from_past_key_values(self):
pass pass
@unittest.skip("Cohere2 has HybridCache and doesn't support low_memory generation")
def test_beam_search_low_memory(self):
pass
@unittest.skip("Cohere2 has HybridCache and doesn't support contrastive generation") @unittest.skip("Cohere2 has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate(self): def test_contrastive_generate(self):
pass pass

View File

@@ -119,10 +119,6 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase):
def test_generate_continue_from_past_key_values(self): def test_generate_continue_from_past_key_values(self):
pass pass
@unittest.skip("Gemma2 has HybridCache and doesn't support low_memory generation")
def test_beam_search_low_memory(self):
pass
@unittest.skip("Gemma2 has HybridCache and doesn't support contrastive generation") @unittest.skip("Gemma2 has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate(self): def test_contrastive_generate(self):
pass pass

View File

@@ -332,12 +332,6 @@ class Qwen2_5_VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
def test_model_is_small(self): def test_model_is_small(self):
pass pass
@unittest.skip(
reason="Qwen2.5-VL can't do low-memory generation because position IDs have extra dimension and split function doesn't work for that"
)
def test_beam_search_low_memory(self):
pass
@unittest.skip( @unittest.skip(
reason="VLMs can't generate from inputs embeds and pixels. This can be tested as part of bacbone LM, no need to run the tes for VLMs" reason="VLMs can't generate from inputs embeds and pixels. This can be tested as part of bacbone LM, no need to run the tes for VLMs"
) )

View File

@@ -344,12 +344,6 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
def test_model_is_small(self): def test_model_is_small(self):
pass pass
@unittest.skip(
reason="Qwen2-VL can't do low-memory generation because position IDs have extra dimension and split function doesn't work for that"
)
def test_beam_search_low_memory(self):
pass
@unittest.skip( @unittest.skip(
reason="VLMs can't generate from inputs embeds and pixels. This can be tested as part of bacbone LM, no need to run the test for VLMs" reason="VLMs can't generate from inputs embeds and pixels. This can be tested as part of bacbone LM, no need to run the test for VLMs"
) )