From 179d02ffb8f3e0f0f3330dd2762286e83cbaa65b Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 18 Mar 2025 18:39:36 +0000 Subject: [PATCH] =?UTF-8?q?[generate]=20=E2=9C=A8=20vectorized=20beam=20se?= =?UTF-8?q?arch=20=E2=9C=A8=20(#35802)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/transformers/generation/tf_utils.py | 2 +- src/transformers/generation/utils.py | 550 ++++++++++++------ src/transformers/models/rag/modeling_rag.py | 12 +- tests/generation/test_utils.py | 113 ++-- tests/models/cohere2/test_modeling_cohere2.py | 4 - tests/models/gemma2/test_modeling_gemma2.py | 4 - .../qwen2_5_vl/test_modeling_qwen2_5_vl.py | 6 - .../models/qwen2_vl/test_modeling_qwen2_vl.py | 6 - 8 files changed, 426 insertions(+), 271 deletions(-) diff --git a/src/transformers/generation/tf_utils.py b/src/transformers/generation/tf_utils.py index dd1f819fe5..98adb32df0 100644 --- a/src/transformers/generation/tf_utils.py +++ b/src/transformers/generation/tf_utils.py @@ -2118,7 +2118,7 @@ class TFGenerationMixin: a greedy approach, otherwise does multinomial sampling without replacement. Parameters: - input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): + input_ids (`tf.Tensor` of shape `(batch_size, num_beams, sequence_length)`): The sequence used as a prompt for the generation. do_sample (`bool`, *optional*, defaults to `False`): Whether or not to use sampling ; use greedy decoding otherwise. diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 3761b59da9..3165ad1c77 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2322,29 +2322,16 @@ class GenerationMixin: ) elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH): - # 11. 
prepare beam search scorer - beam_scorer = BeamSearchScorer( - batch_size=batch_size, - num_beams=generation_config.num_beams, - device=inputs_tensor.device, - length_penalty=generation_config.length_penalty, - do_early_stopping=generation_config.early_stopping, - num_beam_hyps_to_keep=generation_config.num_return_sequences, - max_length=generation_config.max_length, - ) - - # 12. interleave input_ids with `num_beams` additional sequences per batch + # 11. interleave input_ids with `num_beams` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids=input_ids, expand_size=generation_config.num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs, ) - - # 13. run beam sample + # 12. run beam sample result = self._beam_search( input_ids, - beam_scorer, logits_processor=prepared_logits_processor, stopping_criteria=prepared_stopping_criteria, generation_config=generation_config, @@ -3396,6 +3383,7 @@ class GenerationMixin: else: return input_ids + # Auxiliary functions for beam search def _temporary_reorder_cache(self, past_key_values, beam_idx): """ Temporary function to handle the different types of cache reordering processes while we roll out `Cache`. @@ -3422,10 +3410,208 @@ class GenerationMixin: past_key_values.reorder_cache(beam_idx) return past_key_values + @staticmethod + def _flatten_beam_dim(tensor: torch.Tensor) -> torch.Tensor: + """[batch_size, num_beams, ...] -> [batch_size * num_beams, ...]""" + shape = list(tensor.shape) + return torch.reshape(tensor, [shape[0] * shape[1]] + shape[2:]) + + @staticmethod + def _unflatten_beam_dim(tensor: torch.Tensor, batch_size: int, num_beams: int) -> torch.Tensor: + """[batch_size * num_beams, ...] 
-> [batch_size, num_beams, ...]""" + shape = list(tensor.shape) + return torch.reshape(tensor, [batch_size, num_beams] + shape[1:]) + + @staticmethod + def _gather_beams(tensor: torch.Tensor, beam_indices: torch.Tensor) -> torch.Tensor: + """ + Gathers the beam slices indexed by beam_indices into a new beam array. + + Args: + tensor (`torch.Tensor`): A tensor containing data to be gathered. The tensor is a 2D or a 3D tensor + with the first two dimensions depicting the batch and the beam dimensions. + beam_indices (`torch.Tensor` of shape `(batch_size, num_beams_to_select)`): The indices of the beams to + select. + + Returns: + A tensor with the selected beams + """ + # `take_along_dim` requires its indices arg to have the same number of dims as `input` + while len(beam_indices.shape) < len(tensor.shape): + beam_indices = beam_indices.unsqueeze(-1) + gathered_tensor = torch.take_along_dim(input=tensor, indices=beam_indices, dim=1) + return gathered_tensor + + @staticmethod + def _beam_search_has_unfinished_sequences( + running_beam_scores: torch.Tensor, + beam_scores: torch.Tensor, + is_sent_finished: torch.Tensor, + next_token_hits_stopping_criteria: torch.Tensor, + cur_len: int, + max_length: int, + decoder_prompt_len: int, + early_stopping: Union[bool, str], + length_penalty: float, + ): + """ + Beam Search stopping condition -- halts the generation loop if any of these conditions becomes False + """ + # a. Can the open beams improve the top completed scores? + # early_stopping == False -> apply heuristic = always get the best score from + # `cur_len - decoder_prompt_len`. See the discussion below for more details. + # https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565 + # early_stopping == "never" -> compute the best score from `max_length` or `cur_len`, depending on the + # sign of `length_penalty`. Positive `length_penalty` favors longer sequences, thus we use + # `max_length` there. 
+ if early_stopping == "never" and length_penalty > 0.0: + best_hypothetical_length = max_length - decoder_prompt_len + else: + best_hypothetical_length = cur_len - decoder_prompt_len + best_possible_running_score = running_beam_scores[:, :1] / (best_hypothetical_length**length_penalty) + worst_finished_score = torch.where(is_sent_finished, torch.min(beam_scores, dim=1, keepdim=True)[0], -1.0e9) + improvement_possible = torch.any(best_possible_running_score > worst_finished_score) + + # b. Is there still a beam without fully completed sequences? This is only relevant if early_stopping is + # enabled, where we want to finish as soon as all beams have a completed sequence. + exists_open_beam = ~(torch.all(is_sent_finished) & (early_stopping is True)) + + # c. Have we hit a stopping criteria with all running sequences and have no way to continue? e.g. we have + # reached `max_length` + valid_continuations = ~torch.all(next_token_hits_stopping_criteria) + + return improvement_possible & exists_open_beam & valid_continuations + + def _get_top_k_continuations( + self, + accumulated_log_probs: torch.Tensor, + running_sequences: torch.Tensor, + running_beam_indices: torch.Tensor, + cur_len: int, + decoder_prompt_len: int, + do_sample: bool, + beams_to_keep: int, + num_beams: int, + vocab_size: int, + batch_size: int, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Get top-K continuations given the accumulated log probs on the next token. + + A few notes to understand what's going on: + 1. Each item in batch has `num_beams` * `vocab_size` candidate continuations. For each item, get the + top K [K = (number of EOS tokens + 1) * `num_beams`] candidates with the highest accumulated + log-probabilities, or sample them without replacement using the accumulated scores + 2. We gather the top K (as opposed to `num_beams`, or any number lower than K) here so that we have at + least `num_beams` sequences remaining to continue the live beam search. + 3. 
Note that other stopping criteria might result in impossible to continue beams, i.e. all continuations + selected in this step hit the stopping criteria. + """ + # TODO (joao): This function should take an optional beam scorer function, to manipulate the scores after + # token selection. The function should be an argument exposed, so that custom scoring functions can be + # defined. + + # Gather the top K scores from _all_ beams. + if do_sample: + topk_indices = torch.multinomial( + nn.functional.softmax(accumulated_log_probs, dim=-1), num_samples=beams_to_keep + ) + topk_log_probs = torch.gather(input=accumulated_log_probs, dim=1, index=topk_indices) + else: + topk_log_probs, topk_indices = torch.topk(accumulated_log_probs, k=beams_to_keep) + + # Gather K top beams, recover the beam index by floor division and token id by modulo division + topk_current_beam_indices = topk_indices // vocab_size + topk_running_beam_indices = self._gather_beams(running_beam_indices, topk_current_beam_indices) + topk_running_sequences = self._gather_beams(running_sequences, topk_current_beam_indices) + topk_ids = topk_indices % vocab_size + + # Update sequences for the K top-k new sequences. 
+ topk_running_sequences[:, :, cur_len] = topk_ids + + # we want to store the beam indices with batch information -> real beam index = beam index % num beams + batch_offset = torch.arange(batch_size, device=topk_ids.device).view(-1, 1) * num_beams + batch_modified_indices = topk_current_beam_indices + batch_offset + topk_running_beam_indices[:, :, cur_len - decoder_prompt_len] = batch_modified_indices + + return topk_log_probs, topk_running_sequences, topk_running_beam_indices + + def _get_running_beams_for_next_iteration( + self, + topk_log_probs: torch.Tensor, + topk_running_sequences: torch.Tensor, + topk_running_beam_indices: torch.Tensor, + next_token_hits_stopping_criteria: torch.Tensor, + num_beams: int, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Given the top-K continuations, their scores, and whether they hit a stopping criteria, select the + best non-finished beams to continue beam search in the next iteration. + """ + # To prevent these just finished sequences from being used in subsequent iterations, set their log probs + # to a very large negative value + topk_running_log_probs = topk_log_probs + next_token_hits_stopping_criteria.to(torch.float32) * -1.0e9 + + next_topk_indices = torch.topk(topk_running_log_probs, k=num_beams)[1] + running_sequences = self._gather_beams(topk_running_sequences, next_topk_indices) + running_beam_scores = self._gather_beams(topk_running_log_probs, next_topk_indices) + running_beam_indices = self._gather_beams(topk_running_beam_indices, next_topk_indices) + return running_sequences, running_beam_scores, running_beam_indices + + def _update_finished_beams( + self, + sequences: torch.Tensor, + topk_running_sequences: torch.Tensor, + beam_scores: torch.Tensor, + topk_log_probs: torch.Tensor, + beam_indices: torch.Tensor, + topk_running_beam_indices: torch.Tensor, + is_sent_finished: torch.Tensor, + next_token_hits_stopping_criteria: torch.Tensor, + top_num_beam_mask: torch.Tensor, + num_beams: int, + 
cur_len: int, + decoder_prompt_len: int, + length_penalty: float, + early_stopping: Union[bool, str], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Updates the finished beams if (and only if) there are new completed sequences that have a higher score than + the current finished sequences. + """ + # Only the top `num_beam` sequences can be considered for the final returned sequences. Remember: the + # remaining sequences only exist as a backup to ensure that we have at least `num_beams` sequences to + # continue. + did_top_num_beams_just_finished = next_token_hits_stopping_criteria & top_num_beam_mask[None, :] + + # Further process topk logits for the finished beams + # - add length penalty + topk_log_probs = topk_log_probs / ((cur_len + 1 - decoder_prompt_len) ** length_penalty) + # - make sure no scores can be added anymore if beam is full and early stopping is on + beams_in_batch_are_full = torch.all(is_sent_finished, axis=-1, keepdims=True) & (early_stopping is True) + topk_log_probs += beams_in_batch_are_full.to(torch.float32) * -1.0e9 + # - make sure still running sequences cannot be chosen as finalized beam + topk_log_probs += (~did_top_num_beams_just_finished) * -1.0e9 + + # Get finalized `num_beam` sequences for the next generation step -- combine the previous finalized + # data with the new finalized sequences (if any, non-finalized sequences have a very large negative score + # in this step), and keep the best `num_beams` sequences. 
+ merged_sequences = torch.cat((sequences, topk_running_sequences), dim=1) + merged_scores = torch.cat((beam_scores, topk_log_probs), dim=1) + merged_beam_indices = torch.cat((beam_indices, topk_running_beam_indices), dim=1) + merged_is_sent_finished = torch.cat((is_sent_finished, did_top_num_beams_just_finished), dim=1) + topk_merged_indices = torch.topk(merged_scores, k=num_beams)[1] + sequences = self._gather_beams(merged_sequences, topk_merged_indices) + beam_scores = self._gather_beams(merged_scores, topk_merged_indices) + beam_indices = self._gather_beams(merged_beam_indices, topk_merged_indices) + is_sent_finished = self._gather_beams(merged_is_sent_finished, topk_merged_indices) + return sequences, beam_scores, beam_indices, is_sent_finished + + # end of auxiliary functions for beam search + def _beam_search( self, input_ids: torch.LongTensor, - beam_scorer: BeamScorer, logits_processor: LogitsProcessorList, stopping_criteria: StoppingCriteriaList, generation_config: GenerationConfig, @@ -3436,12 +3622,15 @@ class GenerationMixin: Generates sequences of token ids for models with a language modeling head using **beam search decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. + If it's the first time you're diving into Beam Search, we recommend you read the following blog post: + https://huggingface.co/blog/how-to-generate (especially the beam search section). + + You can recompute the sequence scores from the individual scores using the `compute_transition_scores` function + (https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationMixin.compute_transition_scores) + Parameters: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + input_ids (`torch.LongTensor` of shape `(batch_size*num_beams, sequence_length)`): The sequence used as a prompt for the generation. 
- beam_scorer (`BeamScorer`): - An derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and - sorted during generation. For more information, the documentation of [`BeamScorer`] should be read. logits_processor (`LogitsProcessorList`): An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] used to modify the prediction scores of the language modeling head applied at each generation step. @@ -3464,7 +3653,8 @@ class GenerationMixin: `return_dict_in_generate=True` or a [`~generation.GenerateBeamEncoderDecoderOutput`] if `model.config.is_encoder_decoder=True`. """ - # init values + + # 1. init beam_search values pad_token_id = generation_config._pad_token_tensor eos_token_id = generation_config._eos_token_tensor output_attentions = generation_config.output_attentions @@ -3472,26 +3662,51 @@ class GenerationMixin: output_scores = generation_config.output_scores output_logits = generation_config.output_logits return_dict_in_generate = generation_config.return_dict_in_generate - sequential = generation_config.low_memory do_sample = generation_config.do_sample + early_stopping = generation_config.early_stopping + length_penalty = generation_config.length_penalty + max_length = generation_config.max_length + num_beams = generation_config.num_beams + num_return_sequences = generation_config.num_return_sequences - batch_size = len(beam_scorer._beam_hyps) - num_beams = beam_scorer.num_beams + batch_size_unflattened, cur_len = input_ids.shape + batch_size = batch_size_unflattened // num_beams + # TODO (joao): standardize special cases + if self.__class__.__name__ == "MoshiDepthDecoder": + vocab_size = self.config.audio_vocab_size + elif self.__class__.__name__ == "ImageGPTForCausalImageModeling": + vocab_size = self.get_output_embeddings().out_features + else: + vocab_size = self.config.get_text_config().vocab_size + decoder_prompt_len = cur_len + this_peer_finished = False + + # At each beam 
search step, we want to keep top K [K = (number of EOS tokens + 1) * `num_beams`] candidates + # with the highest log-probabilities, or sample K continuations without replacement. We gather the top K + # (as opposed to `num_beams`, or any number lower than K) so that we have at least `num_beams` sequences + # non-finished to continue the live beam search, in case the top `num_beams` all select an EOS token. + n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0 + beams_to_keep = max(2, 1 + n_eos_tokens) * num_beams + top_num_beam_mask = torch.cat( + (torch.ones((num_beams), dtype=torch.bool), torch.zeros((beams_to_keep - num_beams), dtype=torch.bool)), + dim=0, + ).to(input_ids.device) - batch_beam_size, cur_len = input_ids.shape model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) - if num_beams * batch_size != batch_beam_size: + # (joao) feature lost in the refactor. Probably won't implement, hurts readability with minimal gains (there + # are newer low-memory alternatives like the offloaded cache) + sequential = generation_config.low_memory + if sequential: raise ValueError( - f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." + "`low_memory=True` is not supported after the beam search refactor. Please check the discussion in " + "#35802 *after the PR got merged*, and add a comment there if your questions are not yet answered." ) - # init attention / hidden states / scores tuples - scores = () if (return_dict_in_generate and output_scores) else None + # 2. 
init output tuples + all_scores = () if (return_dict_in_generate and output_scores) else None raw_logits = () if (return_dict_in_generate and output_logits) else None - beam_indices = ( - tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None - ) + beam_indices = () if (return_dict_in_generate and output_logits) else None decoder_attentions = () if (return_dict_in_generate and output_attentions) else None cross_attentions = () if (return_dict_in_generate and output_attentions) else None decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None @@ -3503,184 +3718,195 @@ class GenerationMixin: model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None ) + # 3. init running tensors and static-shaped placeholders + + # per batch, beam-item holding current token in loop and completed sequences + output_fill_value = pad_token_id or eos_token_id[0] if eos_token_id is not None else -1 + running_sequences = torch.full( + (batch_size, num_beams, max_length), + fill_value=output_fill_value, + dtype=torch.int64, + device=input_ids.device, + ) + running_sequences[:, :, :cur_len] = self._unflatten_beam_dim(input_ids, batch_size, num_beams) + sequences = running_sequences.clone().detach() + + # per batch, beam-item score, logprobs # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens # of the first beam are considered to avoid sampling the exact same tokens across all beams. 
- beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) - beam_scores[:, 1:] = -1e9 - beam_scores = beam_scores.view((batch_size * num_beams,)) + running_beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) + running_beam_scores[:, 1:] = -1e9 + beam_scores = torch.full((batch_size, num_beams), fill_value=-1e9, dtype=torch.float, device=input_ids.device) - this_peer_finished = False + # per batch, beam-item state bit indicating if sentence has finished. + is_sent_finished = torch.zeros((batch_size, num_beams), dtype=torch.bool, device=input_ids.device) - decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder + # per batch, beam-item state bit indicating if there are valid continuations. + next_token_hits_stopping_criteria = torch.zeros( + (batch_size, num_beams), dtype=torch.bool, device=input_ids.device + ) + # per batch selected beam indices + running_beam_indices = torch.full( + (batch_size, num_beams, max_length - cur_len), fill_value=-1, dtype=torch.int32, device=input_ids.device + ) + beam_indices = running_beam_indices.clone().detach() + + # 4. run the generation loop while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): - model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + # a. 
Forward current tokens, obtain the logits + flat_running_sequences = self._flatten_beam_dim(running_sequences[:, :, :cur_len]) + model_inputs = self.prepare_inputs_for_generation(flat_running_sequences, **model_kwargs) # prepare variable output controls (note: some models won't accept all output controls) model_inputs.update({"output_attentions": output_attentions} if output_attentions else {}) model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {}) - # if sequential is True, split the input to batches of batch_size and run sequentially - if sequential: - if any( - model_name in self.__class__.__name__.lower() - for model_name in [ - "fsmt", - "reformer", - "ctrl", - "gpt_bigcode", - "transo_xl", - "xlnet", - "cpm", - "jamba", - ] - ): - raise RuntimeError( - f"Currently generation for {self.__class__.__name__} is not supported " - f"for `low_memory beam_search`. Please open an issue on GitHub if you need this feature." - ) - - inputs_per_sub_batches = _split_model_inputs( - model_inputs, - split_size=batch_size, - full_batch_size=batch_beam_size, - config=self.config.get_text_config(), - ) - outputs_per_sub_batch = [ - self(**inputs_per_sub_batch, return_dict=True) for inputs_per_sub_batch in inputs_per_sub_batches - ] - - outputs = stack_model_outputs(outputs_per_sub_batch, self.config.get_text_config()) - - else: # Unchanged original behavior - outputs = self(**model_inputs, return_dict=True) + model_outputs = self(**model_inputs, return_dict=True) # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping model_kwargs = self._update_model_kwargs_for_generation( - outputs, + model_outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, ) if synced_gpus and this_peer_finished: - cur_len = cur_len + 1 continue - # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration - # (the clone itself is always 
small) - # .float() is needed to retain precision for later logits manipulations - next_token_logits = outputs.logits[:, -1, :].clone().float() - next_token_logits = next_token_logits.to(input_ids.device) - next_token_scores = nn.functional.log_softmax( - next_token_logits, dim=-1 - ) # (batch_size * num_beams, vocab_size) + logits = model_outputs.logits[:, -1, :].clone().float() # Clone is needed to avoid keeping a hanging ref + logits = logits.to(input_ids.device) - next_token_scores_processed = logits_processor(input_ids, next_token_scores) - next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as( - next_token_scores_processed - ) + # b. Compute log probs -- get log probabilities from logits, process logits with processors (*e.g.* + # `temperature`, ...), and add new logprobs to existing running logprobs scores. + log_probs = nn.functional.log_softmax(logits, dim=-1) + log_probs = logits_processor(flat_running_sequences, log_probs) - # Store scores, attentions and hidden_states when required + # Store logits, attentions and hidden_states when required if return_dict_in_generate: - if output_scores: - scores += (next_token_scores_processed,) if output_logits: - raw_logits += (next_token_logits,) + raw_logits += (logits.clone(),) + if return_dict_in_generate and output_scores: + all_scores += (log_probs.clone(),) + if output_attentions: decoder_attentions += ( - (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + (model_outputs.decoder_attentions,) + if self.config.is_encoder_decoder + else (model_outputs.attentions,) ) if self.config.is_encoder_decoder: - cross_attentions += (outputs.cross_attentions,) + cross_attentions += (model_outputs.cross_attentions,) + if output_hidden_states: decoder_hidden_states += ( - (outputs.decoder_hidden_states,) + (model_outputs.decoder_hidden_states,) if self.config.is_encoder_decoder - else (outputs.hidden_states,) + else (model_outputs.hidden_states,) ) - # 
reshape for beam search - vocab_size = next_token_scores.shape[-1] - next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) + # This is needed to properly delete logits which may be very large for first iteration + # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration + del model_outputs - # Beam token selection: pick 1 + eos_token_id.shape[0] next tokens for each beam so we have at least 1 - # non eos token per beam. - n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0 - n_tokens_to_keep = max(2, 1 + n_eos_tokens) * num_beams - if do_sample: - probs = nn.functional.softmax(next_token_scores, dim=-1) - next_tokens = torch.multinomial(probs, num_samples=n_tokens_to_keep) - next_token_scores = torch.gather(next_token_scores, -1, next_tokens) - next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1) - next_tokens = torch.gather(next_tokens, -1, _indices) - else: - next_token_scores, next_tokens = torch.topk( - next_token_scores, n_tokens_to_keep, dim=1, largest=True, sorted=True - ) + log_probs = self._unflatten_beam_dim(log_probs, batch_size, num_beams) + log_probs = log_probs + running_beam_scores[:, :, None] + log_probs = torch.reshape(log_probs, (batch_size, num_beams * vocab_size)) - next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor") - next_tokens = next_tokens % vocab_size - - # stateless - beam_outputs = beam_scorer.process( - input_ids, - next_token_scores, - next_tokens, - next_indices, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - beam_indices=beam_indices, + # c. Retrieve top-K continuations, i.e. select the next token (greedy or sampling) and then keep the best + # continuations among all beams based on the accumulated scores. 
+ topk_log_probs, topk_running_sequences, topk_running_beam_indices = self._get_top_k_continuations( + accumulated_log_probs=log_probs, + running_sequences=running_sequences, + running_beam_indices=running_beam_indices, + cur_len=cur_len, decoder_prompt_len=decoder_prompt_len, + do_sample=do_sample, + beams_to_keep=beams_to_keep, + num_beams=num_beams, + vocab_size=vocab_size, + batch_size=batch_size, ) - beam_scores = beam_outputs["next_beam_scores"] - beam_next_tokens = beam_outputs["next_beam_tokens"] - beam_idx = beam_outputs["next_beam_indices"] + # d. Check which running sequences have finished + next_token_hits_stopping_criteria = stopping_criteria( + self._flatten_beam_dim(topk_running_sequences[:, :, : cur_len + 1]), # remove unfilled token indexes + all_scores, + ) + next_token_hits_stopping_criteria = self._unflatten_beam_dim( + next_token_hits_stopping_criteria, batch_size, beams_to_keep + ) - input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) + # e. Get the non-finished running `num_beams` sequences for the next generation step + running_sequences, running_beam_scores, running_beam_indices = self._get_running_beams_for_next_iteration( + topk_log_probs=topk_log_probs, + topk_running_sequences=topk_running_sequences, + topk_running_beam_indices=topk_running_beam_indices, + next_token_hits_stopping_criteria=next_token_hits_stopping_criteria, + num_beams=num_beams, + ) - # This is needed to properly delete outputs.logits which may be very large for first iteration - # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration - # IMPORTANT: Note that this should appear BEFORE the call to _reorder_cache() to save the maximum memory - # (that way the memory peak does not include outputs.logits) - del outputs + # f. 
Update the completed beams if a new high score in a finished sequence is found + sequences, beam_scores, beam_indices, is_sent_finished = self._update_finished_beams( + sequences=sequences, + topk_running_sequences=topk_running_sequences, + beam_scores=beam_scores, + topk_log_probs=topk_log_probs, + beam_indices=beam_indices, + topk_running_beam_indices=topk_running_beam_indices, + is_sent_finished=is_sent_finished, + next_token_hits_stopping_criteria=next_token_hits_stopping_criteria, + top_num_beam_mask=top_num_beam_mask, + num_beams=num_beams, + cur_len=cur_len, + decoder_prompt_len=decoder_prompt_len, + length_penalty=length_penalty, + early_stopping=early_stopping, + ) + # g. Prepare remaining data for the next iteration, including computing the stopping condition for + # beam search as a whole (as opposed to individual beams, i.e. `stopping_criteria`) + + # pluck the cache from the beam indices that will be used in the next iteration if model_kwargs.get("past_key_values", None) is not None: model_kwargs["past_key_values"] = self._temporary_reorder_cache( - model_kwargs["past_key_values"], beam_idx + past_key_values=model_kwargs["past_key_values"], + beam_idx=self._flatten_beam_dim(running_beam_indices[..., cur_len - decoder_prompt_len]), ) - if return_dict_in_generate and output_scores: - beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))) - - # increase cur_len cur_len = cur_len + 1 + this_peer_finished = not self._beam_search_has_unfinished_sequences( + running_beam_scores, + beam_scores, + is_sent_finished, + next_token_hits_stopping_criteria, + cur_len, + max_length, + decoder_prompt_len, + early_stopping, + length_penalty, + ) - if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)): - this_peer_finished = True + # 5. 
prepare outputs + # Take best beams for each batch (the score is sorted in descending order) + sequences = self._flatten_beam_dim(sequences[:, :num_return_sequences, :]) + beam_scores = self._flatten_beam_dim(beam_scores[:, :num_return_sequences]) + beam_indices = self._flatten_beam_dim(beam_indices[:, :num_return_sequences, :]) - sequence_outputs = beam_scorer.finalize( - input_ids, - beam_scores, - next_tokens, - next_indices, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - max_length=stopping_criteria.max_length, - beam_indices=beam_indices, - decoder_prompt_len=decoder_prompt_len, - ) + # Crop the static-shaped tensors to the actual size + sequences = sequences[:, :cur_len] + beam_indices = beam_indices[:, : cur_len - decoder_prompt_len] if return_dict_in_generate: if not output_scores: - sequence_outputs["sequence_scores"] = None + beam_scores = None if self.config.is_encoder_decoder: return GenerateBeamEncoderDecoderOutput( - sequences=sequence_outputs["sequences"], - sequences_scores=sequence_outputs["sequence_scores"], - scores=scores, + sequences=sequences, + sequences_scores=beam_scores, + scores=all_scores, logits=raw_logits, - beam_indices=sequence_outputs["beam_indices"], + beam_indices=beam_indices, encoder_attentions=encoder_attentions, encoder_hidden_states=encoder_hidden_states, decoder_attentions=decoder_attentions, @@ -3690,17 +3916,17 @@ class GenerationMixin: ) else: return GenerateBeamDecoderOnlyOutput( - sequences=sequence_outputs["sequences"], - sequences_scores=sequence_outputs["sequence_scores"], - scores=scores, + sequences=sequences, + sequences_scores=beam_scores, + scores=all_scores, logits=raw_logits, - beam_indices=sequence_outputs["beam_indices"], + beam_indices=beam_indices, attentions=decoder_attentions, hidden_states=decoder_hidden_states, past_key_values=model_kwargs.get("past_key_values"), ) else: - return sequence_outputs["sequences"] + return sequences def _group_beam_search( self, @@ -3717,7 +3943,7 @@ class 
GenerationMixin: decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. Parameters: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + input_ids (`torch.LongTensor` of shape `(batch_size*num_beams, sequence_length)`): The sequence used as a prompt for the generation. beam_scorer (`BeamScorer`): An derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and @@ -4008,7 +4234,7 @@ class GenerationMixin: decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. Parameters: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + input_ids (`torch.LongTensor` of shape `(batch_size*num_beams, sequence_length)`): The sequence used as a prompt for the generation. constrained_beam_scorer (`ConstrainedBeamSearchScorer`): A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py index d3ca787691..17a1f44aa1 100644 --- a/src/transformers/models/rag/modeling_rag.py +++ b/src/transformers/models/rag/modeling_rag.py @@ -22,7 +22,7 @@ import torch from torch import nn from ...configuration_utils import PretrainedConfig -from ...generation import BeamSearchScorer, GenerationConfig, LogitsProcessorList, StoppingCriteriaList +from ...generation import GenerationConfig, LogitsProcessorList, StoppingCriteriaList from ...modeling_outputs import ModelOutput from ...modeling_utils import PreTrainedModel from ...utils import add_start_docstrings_to_model_forward, logging, replace_return_docstrings @@ -1563,18 +1563,8 @@ class RagTokenForGeneration(RagPreTrainedModel): elif generation_config.num_beams > 1: if generation_config.num_return_sequences > generation_config.num_beams: raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") - beam_scorer = 
BeamSearchScorer( - batch_size=batch_size, - num_beams=generation_config.num_beams, - device=self.device, - length_penalty=generation_config.length_penalty, - do_early_stopping=generation_config.early_stopping, - num_beam_hyps_to_keep=generation_config.num_return_sequences, - max_length=generation_config.max_length, - ) return self._beam_search( input_ids, - beam_scorer, logits_processor=pre_processor, stopping_criteria=prepared_stopping_criteria, generation_config=generation_config, diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index e8e14b0497..c166bbeec1 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1099,70 +1099,6 @@ class GenerationTesterMixin: ) self.assertListEqual(low_output.tolist(), high_output.tolist()) - @pytest.mark.generate - def test_beam_search_low_memory(self): - # Check that choosing 'low_memory' does not change the model output - for model_class in self.all_generative_model_classes: - if model_class._is_stateful: - self.skipTest(reason="May fix in the future: need custom cache handling") - if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]): - self.skipTest(reason="Won't fix: old model with different cache format") - if any( - model_name in model_class.__name__.lower() - for model_name in [ - "ctrl", - "gptbigcode", - "transo_xl", - "xlnet", - "cpm", - "jamba", - ] - ): - self.skipTest(reason="May fix in the future: need model-specific fixes") - - set_model_tester_for_less_flaky_test(self) - - config, inputs_dict = self.prepare_config_and_inputs_for_generate() - set_config_for_less_flaky_test(config) - # batch_size=1 is ok, but batch_size>1 will cause non-identical output - - config.use_cache = True - config.is_decoder = True - - # test output equality of low versus high memory - model = model_class(config).to(torch_device).eval() - set_model_for_less_flaky_test(model) - - logits_processor_kwargs = 
self._get_logits_processor_kwargs(config=model.config) - - low_output = model.generate( - **inputs_dict, - max_new_tokens=8, - num_beams=5, - early_stopping=True, - low_memory=True, - use_cache=True, - output_scores=True, - output_logits=True, - return_dict_in_generate=True, - **logits_processor_kwargs, - ) - - high_output = model.generate( - **inputs_dict, - max_new_tokens=8, - num_beams=5, - early_stopping=True, - low_memory=False, - use_cache=True, - output_scores=True, - output_logits=True, - return_dict_in_generate=True, - **logits_processor_kwargs, - ) - # The two outputs must match and their shape must be as expected - self._check_similar_generate_outputs(low_output, high_output) - @parameterized.expand([("random",), ("same",)]) @pytest.mark.generate def test_assisted_decoding_matches_greedy_search(self, assistant_type): @@ -2964,19 +2900,6 @@ class GenerationIntegrationTests(unittest.TestCase): torch.testing.assert_close(transition_scores_sum, outputs.sequences_scores, rtol=1e-3, atol=1e-3) - def test_beam_search_low_memory(self): - tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2") - model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2") - tokenizer.pad_token_id = tokenizer.eos_token_id - model_inputs = tokenizer("I", return_tensors="pt")["input_ids"] - - low_output = model.generate(model_inputs, max_new_tokens=40, num_beams=5, early_stopping=True, low_memory=True) - - high_output = model.generate( - model_inputs, max_new_tokens=40, num_beams=5, early_stopping=True, low_memory=False - ) - self.assertListEqual(low_output.tolist(), high_output.tolist()) - @slow def test_green_red_watermark_generation(self): model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) @@ -4311,6 +4234,42 @@ class GenerationIntegrationTests(unittest.TestCase): self.assertEqual(decoded_assisted, [expected_output]) @slow + def test_beam_search_advanced_stopping_criteria(self): + """ + Tests that beam search 
works with a stopping criterion that is not max length or EOS token. Prior to the beam + search vectorization PR (#35802), beam search was not accepting other stopping criteria. Test inspired by + the original issue (#34843). + """ + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") + model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct").to(torch_device) + + prompt = ( + "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. " + "How many clips did Natalia sell altogether in April and May?" + ) + tokens = tokenizer(prompt, return_tensors="pt").to(torch_device) + generation_config = GenerationConfig(num_beams=3, do_sample=False, length_penalty=1.0, max_new_tokens=100) + + # This particular prompt should result in a ":" being present in the answer + out = model.generate(**tokens, generation_config=generation_config, tokenizer=tokenizer) + output_text = tokenizer.decode(out[0], skip_special_tokens=True) + last_non_special_token_decoded = tokenizer.decode(out[out != tokenizer.pad_token_id][-1]) + self.assertTrue(":" in output_text) + self.assertFalse(":" in output_text[-5:]) + self.assertFalse(":" in last_non_special_token_decoded) + + # Adding an advanced stopping criterion: text generation should stop when a ":" is generated. 
+ # Note that: + # 1 - the text up to ":" doesn't have to be the same, it can belong to a different beam + # 2 - ":" may not be the last char, but it must be in the last non-special token + generation_config.stop_strings = ":" + out = model.generate(**tokens, generation_config=generation_config, tokenizer=tokenizer) + output_text = tokenizer.decode(out[0], skip_special_tokens=True) + last_non_special_token_decoded = tokenizer.decode(out[out != tokenizer.pad_token_id][-1]) + self.assertTrue(":" in output_text) + self.assertTrue(":" in output_text[-5:]) + self.assertTrue(":" in last_non_special_token_decoded) + def test_max_time(self): tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2") model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2") diff --git a/tests/models/cohere2/test_modeling_cohere2.py b/tests/models/cohere2/test_modeling_cohere2.py index 699eb15ddb..789649b832 100644 --- a/tests/models/cohere2/test_modeling_cohere2.py +++ b/tests/models/cohere2/test_modeling_cohere2.py @@ -104,10 +104,6 @@ class Cohere2ModelTest(CohereModelTest, unittest.TestCase): def test_generate_continue_from_past_key_values(self): pass - @unittest.skip("Cohere2 has HybridCache and doesn't support low_memory generation") - def test_beam_search_low_memory(self): - pass - @unittest.skip("Cohere2 has HybridCache and doesn't support contrastive generation") def test_contrastive_generate(self): pass diff --git a/tests/models/gemma2/test_modeling_gemma2.py b/tests/models/gemma2/test_modeling_gemma2.py index 3a51e0bbf7..cf27d90a9b 100644 --- a/tests/models/gemma2/test_modeling_gemma2.py +++ b/tests/models/gemma2/test_modeling_gemma2.py @@ -119,10 +119,6 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase): def test_generate_continue_from_past_key_values(self): pass - @unittest.skip("Gemma2 has HybridCache and doesn't support low_memory generation") - def test_beam_search_low_memory(self): - pass - @unittest.skip("Gemma2 has HybridCache and doesn't support 
contrastive generation") def test_contrastive_generate(self): pass diff --git a/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py b/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py index dcb0816a0d..4b5bb61eb8 100644 --- a/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py +++ b/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py @@ -332,12 +332,6 @@ class Qwen2_5_VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test def test_model_is_small(self): pass - @unittest.skip( - reason="Qwen2.5-VL can't do low-memory generation because position IDs have extra dimension and split function doesn't work for that" - ) - def test_beam_search_low_memory(self): - pass - @unittest.skip( reason="VLMs can't generate from inputs embeds and pixels. This can be tested as part of bacbone LM, no need to run the tes for VLMs" ) diff --git a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py index 655effb09d..e1eafbf693 100644 --- a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py +++ b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py @@ -344,12 +344,6 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas def test_model_is_small(self): pass - @unittest.skip( - reason="Qwen2-VL can't do low-memory generation because position IDs have extra dimension and split function doesn't work for that" - ) - def test_beam_search_low_memory(self): - pass - @unittest.skip( reason="VLMs can't generate from inputs embeds and pixels. This can be tested as part of bacbone LM, no need to run the test for VLMs" )