From 356fd681098f4b33a6f95660a5a0252eae313348 Mon Sep 17 00:00:00 2001 From: Guang Yang Date: Tue, 8 Jul 2025 03:59:37 -0500 Subject: [PATCH] fix(generation): stop beam search per-instance when heuristic satisfied (#38778) * fix(decoding): stop beam search per-instance when heuristic satisfied Previously, when early_stopping is set to `False`, the early-stopping heuristic only halted generation when **all** batch instances reached the criterion. This caused instances that are impossible (suggested by the heuristic) to improve keep generating, leading to inconsistent and overlong outputs across the batch. Now we apply the heuristic **per-instance**: once a certain instance of batch has its all beams impossibe to improve, we mark that instance finished while letting others continue. This restores expected behavior and ensures consistency in batched generation. * Add test case GenerationIntegrationTests.test_beam_search_early_stop_heuristic * Update naming improvement_possibility -> is_early_stop_heuristic_unsatisfied * Add comments for early stop heuristic * Update src/transformers/generation/utils.py --------- Co-authored-by: Joao Gante --- src/transformers/generation/utils.py | 90 +++++++++++++++++++--------- tests/generation/test_utils.py | 47 +++++++++++++++ tests/utils/test_cache_utils.py | 2 +- 3 files changed, 109 insertions(+), 30 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index e36417269d..c778e9e012 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -3764,11 +3764,11 @@ class GenerationMixin(ContinuousMixin): return gathered_tensor @staticmethod - def _beam_search_has_unfinished_sequences( + def _check_early_stop_heuristic( + is_early_stop_heuristic_unsatisfied: torch.Tensor, running_beam_scores: torch.Tensor, beam_scores: torch.Tensor, is_sent_finished: torch.Tensor, - next_token_hits_stopping_criteria: torch.Tensor, cur_len: int, max_length: int, decoder_prompt_len: int, @@ -3776,34 +3776,52 @@ class GenerationMixin(ContinuousMixin): length_penalty: float, ): """ - Beam Search stopping condition -- halts the generation loop if any of these conditions becomes False + Determine whether early stopping is possible by checking if the best possible score of running beams + could still improve upon the finished ones. + + Mechanism: + - Without a length penalty, beam scores typically decrease as more tokens are generated. + So, if the *best possible* score from any running beam is already worse than the *worst* finished beam, + we can safely stop early. + - With a length penalty, scores may increase with longer sequences. In this case, we use heuristics + to estimate the best possible score — though this estimate may not always be correct — and stop + if no further improvement seems likely. + + We apply different heuristics depending on the value of `early_stopping`: + 1. `early_stopping == False`: + -> Use a heuristic that assumes the best score comes from the current length minus the decoder prompt length. + -> See detailed discussion: https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565 + + 2. `early_stopping == "never"`: + -> Estimate the best score using either `max_length` or `cur_len`, depending on the sign of `length_penalty`. + -> A positive length penalty favors longer sequences, so we use `max_length` in that case. + + NOTE: the canonical beam search implementation can be replicated with `early_stopping="never"` and + `length_penalty=0.0`, which are NOT the default flags. The default behavior was empirically found to produce + better sequences (prior to 2022), and changing it is BC breaking. """ - # a. Can the open beams improve the top completed scores? - # early_stopping == False -> apply heuristic = always get the best score from `cur_len - decoder_prompt_len`. - # early_stopping == "never" -> compute the best score from `max_length` or `cur_len`, depending on the - # sign of `length_penalty`. Positive `length_penalty` favors longer sequences, thus we use - # `max_length` there. - # !! - # Be sure to check the docstring for `early_stopping` and `length_penalty`. The default parameterization - # does NOT correspond to a canonical beam search implementation, and tends to favor shorter output sequences - # compared to it (the heuristic active by default underestimates the maximum achievable score, and thus cut - # generation short). Also, be mindful that length penalty > 0.0 actually favors longer sequences, despite - # its name. These modifications were empirically found in the past (prior to 2022) to produce better quality - # generations, and changing them is BC breaking. - # For a canonical beam search implementation, set `early_stopping="never"` and `length_penalty=0.0`. - # See the discussion below for more details. - # https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565 - # !! if early_stopping == "never" and length_penalty > 0.0: best_hypothetical_length = max_length - decoder_prompt_len else: best_hypothetical_length = cur_len - decoder_prompt_len - - # best-case scenario: the next tokens have logprobs=0 (probability=1), and the score stays the same before - # applying length penalty best_possible_running_score = running_beam_scores[:, :1] / (best_hypothetical_length**length_penalty) worst_finished_score = torch.where(is_sent_finished, torch.min(beam_scores, dim=1, keepdim=True)[0], -1.0e9) - improvement_possible = torch.any(best_possible_running_score > worst_finished_score) + return is_early_stop_heuristic_unsatisfied & torch.any( + best_possible_running_score > worst_finished_score, dim=-1, keepdim=True + ) + + @staticmethod + def _beam_search_has_unfinished_sequences( + is_early_stop_heuristic_unsatisfied: torch.Tensor, + is_sent_finished: torch.Tensor, + next_token_hits_stopping_criteria: torch.Tensor, + early_stopping: Union[bool, str], + ): + """ + Beam Search stopping condition -- halts the generation loop if any of these conditions becomes False + """ + # a. Can the open beams improve the top completed scores? + improvement_possible = torch.any(is_early_stop_heuristic_unsatisfied) # b. Is there still a beam without fully completed sequences? This is only relevant if early_stopping is # enabled, where we want to finish as soon as all beams have a completed sequence. @@ -3899,6 +3917,7 @@ class GenerationMixin(ContinuousMixin): topk_log_probs: torch.Tensor, beam_indices: torch.Tensor, topk_running_beam_indices: torch.Tensor, + is_early_stop_heuristic_unsatisfied: torch.Tensor, is_sent_finished: torch.Tensor, next_token_hits_stopping_criteria: torch.Tensor, top_num_beam_mask: torch.Tensor, @@ -3923,6 +3942,9 @@ class GenerationMixin(ContinuousMixin): # - make sure no scores can be added anymore if beam is full and early stopping is on beams_in_batch_are_full = torch.all(is_sent_finished, axis=-1, keepdims=True) & (early_stopping is True) topk_log_probs += beams_in_batch_are_full.to(torch.float32) * -1.0e9 + # - make sure no scores can be added anymore if improvement is not possible + topk_log_probs += (~is_early_stop_heuristic_unsatisfied).to(torch.float32) * -1.0e9 + # - make sure still running sequences cannot be chosen as finalized beam topk_log_probs += (~did_top_num_beams_just_finished) * -1.0e9 @@ -4074,6 +4096,9 @@ class GenerationMixin(ContinuousMixin): # per batch, beam-item state bit indicating if sentence has finished. is_sent_finished = torch.zeros((batch_size, num_beams), dtype=torch.bool, device=input_ids.device) + # per batch state bit indicating if there is a possibility to improve the best finished sentence. + is_early_stop_heuristic_unsatisfied = torch.ones((batch_size, 1), dtype=torch.bool, device=input_ids.device) + # per batch, beam-item state bit indicating if there are valid continuations. next_token_hits_stopping_criteria = torch.zeros( (batch_size, num_beams), dtype=torch.bool, device=input_ids.device @@ -4186,6 +4211,7 @@ class GenerationMixin(ContinuousMixin): topk_log_probs=topk_log_probs, beam_indices=beam_indices, topk_running_beam_indices=topk_running_beam_indices, + is_early_stop_heuristic_unsatisfied=is_early_stop_heuristic_unsatisfied, is_sent_finished=is_sent_finished, next_token_hits_stopping_criteria=next_token_hits_stopping_criteria, top_num_beam_mask=top_num_beam_mask, @@ -4207,16 +4233,22 @@ class GenerationMixin(ContinuousMixin): ) cur_len = cur_len + 1 + is_early_stop_heuristic_unsatisfied = self._check_early_stop_heuristic( + is_early_stop_heuristic_unsatisfied=is_early_stop_heuristic_unsatisfied, + running_beam_scores=running_beam_scores, + beam_scores=beam_scores, + is_sent_finished=is_sent_finished, + cur_len=cur_len, + max_length=max_length, + decoder_prompt_len=decoder_prompt_len, + early_stopping=early_stopping, + length_penalty=length_penalty, + ) this_peer_finished = not self._beam_search_has_unfinished_sequences( - running_beam_scores, - beam_scores, + is_early_stop_heuristic_unsatisfied, is_sent_finished, next_token_hits_stopping_criteria, - cur_len, - max_length, - decoder_prompt_len, early_stopping, - length_penalty, ) # 5. prepare outputs diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index c6fd59f68b..9edb1fe99f 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2887,6 +2887,53 @@ class GenerationIntegrationTests(unittest.TestCase): ], ) + @slow + def test_beam_search_early_stop_heuristic(self): + """Regression test for #38778 (early stopping needs to be tracked at a batch level)""" + EXPECTED_OUTPUT = ( + "<|user|>\nWhat is 3+5?\n<|assistant|>\nThe sum of 3 and 5 is 8. \n\nSo, 3 + 5 = 8. \n\n" + "Let's confirm this using Python code:\n\n```python\n# Define the numbers\nnum1 = 3\nnum2 = 5\n\n" + "# Calculate the sum\nresult = num1 + num2\n\n# Print the result\nprint(result)\n```\n" + "```output\n8\n```\nThe sum of 3 and 5 is \\(\\boxed{8}\\)." + ) + + model = AutoModelForCausalLM.from_pretrained("allenai/OLMo-2-0425-1B-Instruct").to(torch_device) + tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-2-0425-1B-Instruct", padding_side="left") + generation_config = GenerationConfig( + num_beams=10, + max_new_tokens=256, + length_penalty=2, + ) + # batch of 1 + question = [{"role": "user", "content": "What is 3+5?"}] + question = tokenizer.apply_chat_template( + question, tokenize=False, add_generation_prompt=True, return_tensors="pt" + ) + inputs = tokenizer(question, return_tensors="pt", padding=True).to("cuda") + outputs = model.generate(**inputs, generation_config=generation_config) + responses = tokenizer.batch_decode(outputs, skip_special_tokens=True) + self.assertEqual(responses[0], EXPECTED_OUTPUT) + + # batch of 2 + question = [{"role": "user", "content": "What is 3+5?"}] + cot_question = [ + { + "role": "user", + "content": "What is 3+5? Explain your reasoning step by step, and provide the final answer at the end.", + } + ] + question = tokenizer.apply_chat_template( + question, tokenize=False, add_generation_prompt=True, return_tensors="pt" + ) + cot_question = tokenizer.apply_chat_template( + cot_question, tokenize=False, add_generation_prompt=True, return_tensors="pt" + ) + inputs = tokenizer([question, cot_question], return_tensors="pt", padding=True).to("cuda") + + outputs = model.generate(**inputs, generation_config=generation_config) + responses = tokenizer.batch_decode(outputs, skip_special_tokens=True) + self.assertEqual(responses[0], EXPECTED_OUTPUT) + def test_max_length_if_input_embeds(self): article = "Today a dragon flew over Paris." model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index 4c6b352168..0c3c90768d 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -840,7 +840,7 @@ class CacheExportIntegrationTest(unittest.TestCase): input_ids = torch.zeros((1, 3), dtype=torch.long) cache_position = torch.tensor([0, 1, 2], dtype=torch.long) dynamic_shapes = {"input_ids": {1: torch.export.Dim.DYNAMIC}, "cache_position": {0: torch.export.Dim.DYNAMIC}} - strict = version.parse(torch.__version__) != version.parse("2.7.0") + strict = version.parse(torch.__version__) < version.parse("2.7.0") exported_program = exportable_module.export( input_ids=input_ids, cache_position=cache_position,