fix(generation): stop beam search per-instance when heuristic satisfied (#38778)

* fix(decoding): stop beam search per-instance when heuristic satisfied

Previously, when early_stopping is set to `False`, the early-stopping heuristic only halted generation when **all** batch instances reached the criterion. This caused instances that are impossible (suggested by the heuristic) to improve keep generating, leading to inconsistent and overlong outputs across the batch.

Now we apply the heuristic **per-instance**: once a certain instance of batch has its all beams impossibe to improve, we mark that instance finished while letting others continue. This restores expected behavior and ensures consistency in batched generation.

* Add test case GenerationIntegrationTests.test_beam_search_early_stop_heuristic

* Update naming improvement_possibility -> is_early_stop_heuristic_unsatisfied

* Add comments for early stop heuristic

* Update src/transformers/generation/utils.py

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
Guang Yang
2025-07-08 03:59:37 -05:00
committed by GitHub
parent 0b0ede8b2b
commit 356fd68109
3 changed files with 109 additions and 30 deletions

View File

@@ -3764,11 +3764,11 @@ class GenerationMixin(ContinuousMixin):
return gathered_tensor return gathered_tensor
@staticmethod @staticmethod
def _beam_search_has_unfinished_sequences( def _check_early_stop_heuristic(
is_early_stop_heuristic_unsatisfied: torch.Tensor,
running_beam_scores: torch.Tensor, running_beam_scores: torch.Tensor,
beam_scores: torch.Tensor, beam_scores: torch.Tensor,
is_sent_finished: torch.Tensor, is_sent_finished: torch.Tensor,
next_token_hits_stopping_criteria: torch.Tensor,
cur_len: int, cur_len: int,
max_length: int, max_length: int,
decoder_prompt_len: int, decoder_prompt_len: int,
@@ -3776,34 +3776,52 @@ class GenerationMixin(ContinuousMixin):
length_penalty: float, length_penalty: float,
): ):
""" """
Beam Search stopping condition -- halts the generation loop if any of these conditions becomes False Determine whether early stopping is possible by checking if the best possible score of running beams
could still improve upon the finished ones.
Mechanism:
- Without a length penalty, beam scores typically decrease as more tokens are generated.
So, if the *best possible* score from any running beam is already worse than the *worst* finished beam,
we can safely stop early.
- With a length penalty, scores may increase with longer sequences. In this case, we use heuristics
to estimate the best possible score — though this estimate may not always be correct — and stop
if no further improvement seems likely.
We apply different heuristics depending on the value of `early_stopping`:
1. `early_stopping == False`:
-> Use a heuristic that assumes the best score comes from the current length minus the decoder prompt length.
-> See detailed discussion: https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
2. `early_stopping == "never"`:
-> Estimate the best score using either `max_length` or `cur_len`, depending on the sign of `length_penalty`.
-> A positive length penalty favors longer sequences, so we use `max_length` in that case.
NOTE: the canonical beam search implementation can be replicated with `early_stopping="never"` and
`length_penalty=0.0`, which are NOT the default flags. The default behavior was empirically found to produce
better sequences (prior to 2022), and changing it is BC breaking.
""" """
# a. Can the open beams improve the top completed scores?
# early_stopping == False -> apply heuristic = always get the best score from `cur_len - decoder_prompt_len`.
# early_stopping == "never" -> compute the best score from `max_length` or `cur_len`, depending on the
# sign of `length_penalty`. Positive `length_penalty` favors longer sequences, thus we use
# `max_length` there.
# !!
# Be sure to check the docstring for `early_stopping` and `length_penalty`. The default parameterization
# does NOT correspond to a canonical beam search implementation, and tends to favor shorter output sequences
# compared to it (the heuristic active by default underestimates the maximum achievable score, and thus cut
# generation short). Also, be mindful that length penalty > 0.0 actually favors longer sequences, despite
# its name. These modifications were empirically found in the past (prior to 2022) to produce better quality
# generations, and changing them is BC breaking.
# For a canonical beam search implementation, set `early_stopping="never"` and `length_penalty=0.0`.
# See the discussion below for more details.
# https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
# !!
if early_stopping == "never" and length_penalty > 0.0: if early_stopping == "never" and length_penalty > 0.0:
best_hypothetical_length = max_length - decoder_prompt_len best_hypothetical_length = max_length - decoder_prompt_len
else: else:
best_hypothetical_length = cur_len - decoder_prompt_len best_hypothetical_length = cur_len - decoder_prompt_len
# best-case scenario: the next tokens have logprobs=0 (probability=1), and the score stays the same before
# applying length penalty
best_possible_running_score = running_beam_scores[:, :1] / (best_hypothetical_length**length_penalty) best_possible_running_score = running_beam_scores[:, :1] / (best_hypothetical_length**length_penalty)
worst_finished_score = torch.where(is_sent_finished, torch.min(beam_scores, dim=1, keepdim=True)[0], -1.0e9) worst_finished_score = torch.where(is_sent_finished, torch.min(beam_scores, dim=1, keepdim=True)[0], -1.0e9)
improvement_possible = torch.any(best_possible_running_score > worst_finished_score) return is_early_stop_heuristic_unsatisfied & torch.any(
best_possible_running_score > worst_finished_score, dim=-1, keepdim=True
)
@staticmethod
def _beam_search_has_unfinished_sequences(
is_early_stop_heuristic_unsatisfied: torch.Tensor,
is_sent_finished: torch.Tensor,
next_token_hits_stopping_criteria: torch.Tensor,
early_stopping: Union[bool, str],
):
"""
Beam Search stopping condition -- halts the generation loop if any of these conditions becomes False
"""
# a. Can the open beams improve the top completed scores?
improvement_possible = torch.any(is_early_stop_heuristic_unsatisfied)
# b. Is there still a beam without fully completed sequences? This is only relevant if early_stopping is # b. Is there still a beam without fully completed sequences? This is only relevant if early_stopping is
# enabled, where we want to finish as soon as all beams have a completed sequence. # enabled, where we want to finish as soon as all beams have a completed sequence.
@@ -3899,6 +3917,7 @@ class GenerationMixin(ContinuousMixin):
topk_log_probs: torch.Tensor, topk_log_probs: torch.Tensor,
beam_indices: torch.Tensor, beam_indices: torch.Tensor,
topk_running_beam_indices: torch.Tensor, topk_running_beam_indices: torch.Tensor,
is_early_stop_heuristic_unsatisfied: torch.Tensor,
is_sent_finished: torch.Tensor, is_sent_finished: torch.Tensor,
next_token_hits_stopping_criteria: torch.Tensor, next_token_hits_stopping_criteria: torch.Tensor,
top_num_beam_mask: torch.Tensor, top_num_beam_mask: torch.Tensor,
@@ -3923,6 +3942,9 @@ class GenerationMixin(ContinuousMixin):
# - make sure no scores can be added anymore if beam is full and early stopping is on # - make sure no scores can be added anymore if beam is full and early stopping is on
beams_in_batch_are_full = torch.all(is_sent_finished, axis=-1, keepdims=True) & (early_stopping is True) beams_in_batch_are_full = torch.all(is_sent_finished, axis=-1, keepdims=True) & (early_stopping is True)
topk_log_probs += beams_in_batch_are_full.to(torch.float32) * -1.0e9 topk_log_probs += beams_in_batch_are_full.to(torch.float32) * -1.0e9
# - make sure no scores can be added anymore if improvement is not possible
topk_log_probs += (~is_early_stop_heuristic_unsatisfied).to(torch.float32) * -1.0e9
# - make sure still running sequences cannot be chosen as finalized beam # - make sure still running sequences cannot be chosen as finalized beam
topk_log_probs += (~did_top_num_beams_just_finished) * -1.0e9 topk_log_probs += (~did_top_num_beams_just_finished) * -1.0e9
@@ -4074,6 +4096,9 @@ class GenerationMixin(ContinuousMixin):
# per batch, beam-item state bit indicating if sentence has finished. # per batch, beam-item state bit indicating if sentence has finished.
is_sent_finished = torch.zeros((batch_size, num_beams), dtype=torch.bool, device=input_ids.device) is_sent_finished = torch.zeros((batch_size, num_beams), dtype=torch.bool, device=input_ids.device)
# per batch state bit indicating if there is a possibility to improve the best finished sentence.
is_early_stop_heuristic_unsatisfied = torch.ones((batch_size, 1), dtype=torch.bool, device=input_ids.device)
# per batch, beam-item state bit indicating if there are valid continuations. # per batch, beam-item state bit indicating if there are valid continuations.
next_token_hits_stopping_criteria = torch.zeros( next_token_hits_stopping_criteria = torch.zeros(
(batch_size, num_beams), dtype=torch.bool, device=input_ids.device (batch_size, num_beams), dtype=torch.bool, device=input_ids.device
@@ -4186,6 +4211,7 @@ class GenerationMixin(ContinuousMixin):
topk_log_probs=topk_log_probs, topk_log_probs=topk_log_probs,
beam_indices=beam_indices, beam_indices=beam_indices,
topk_running_beam_indices=topk_running_beam_indices, topk_running_beam_indices=topk_running_beam_indices,
is_early_stop_heuristic_unsatisfied=is_early_stop_heuristic_unsatisfied,
is_sent_finished=is_sent_finished, is_sent_finished=is_sent_finished,
next_token_hits_stopping_criteria=next_token_hits_stopping_criteria, next_token_hits_stopping_criteria=next_token_hits_stopping_criteria,
top_num_beam_mask=top_num_beam_mask, top_num_beam_mask=top_num_beam_mask,
@@ -4207,16 +4233,22 @@ class GenerationMixin(ContinuousMixin):
) )
cur_len = cur_len + 1 cur_len = cur_len + 1
is_early_stop_heuristic_unsatisfied = self._check_early_stop_heuristic(
is_early_stop_heuristic_unsatisfied=is_early_stop_heuristic_unsatisfied,
running_beam_scores=running_beam_scores,
beam_scores=beam_scores,
is_sent_finished=is_sent_finished,
cur_len=cur_len,
max_length=max_length,
decoder_prompt_len=decoder_prompt_len,
early_stopping=early_stopping,
length_penalty=length_penalty,
)
this_peer_finished = not self._beam_search_has_unfinished_sequences( this_peer_finished = not self._beam_search_has_unfinished_sequences(
running_beam_scores, is_early_stop_heuristic_unsatisfied,
beam_scores,
is_sent_finished, is_sent_finished,
next_token_hits_stopping_criteria, next_token_hits_stopping_criteria,
cur_len,
max_length,
decoder_prompt_len,
early_stopping, early_stopping,
length_penalty,
) )
# 5. prepare outputs # 5. prepare outputs

View File

@@ -2887,6 +2887,53 @@ class GenerationIntegrationTests(unittest.TestCase):
], ],
) )
@slow
def test_beam_search_early_stop_heuristic(self):
"""Regression test for #38778 (early stopping needs to be tracked at a batch level)"""
EXPECTED_OUTPUT = (
"<|user|>\nWhat is 3+5?\n<|assistant|>\nThe sum of 3 and 5 is 8. \n\nSo, 3 + 5 = 8. \n\n"
"Let's confirm this using Python code:\n\n```python\n# Define the numbers\nnum1 = 3\nnum2 = 5\n\n"
"# Calculate the sum\nresult = num1 + num2\n\n# Print the result\nprint(result)\n```\n"
"```output\n8\n```\nThe sum of 3 and 5 is \\(\\boxed{8}\\)."
)
model = AutoModelForCausalLM.from_pretrained("allenai/OLMo-2-0425-1B-Instruct").to(torch_device)
tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-2-0425-1B-Instruct", padding_side="left")
generation_config = GenerationConfig(
num_beams=10,
max_new_tokens=256,
length_penalty=2,
)
# batch of 1
question = [{"role": "user", "content": "What is 3+5?"}]
question = tokenizer.apply_chat_template(
question, tokenize=False, add_generation_prompt=True, return_tensors="pt"
)
inputs = tokenizer(question, return_tensors="pt", padding=True).to("cuda")
outputs = model.generate(**inputs, generation_config=generation_config)
responses = tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertEqual(responses[0], EXPECTED_OUTPUT)
# batch of 2
question = [{"role": "user", "content": "What is 3+5?"}]
cot_question = [
{
"role": "user",
"content": "What is 3+5? Explain your reasoning step by step, and provide the final answer at the end.",
}
]
question = tokenizer.apply_chat_template(
question, tokenize=False, add_generation_prompt=True, return_tensors="pt"
)
cot_question = tokenizer.apply_chat_template(
cot_question, tokenize=False, add_generation_prompt=True, return_tensors="pt"
)
inputs = tokenizer([question, cot_question], return_tensors="pt", padding=True).to("cuda")
outputs = model.generate(**inputs, generation_config=generation_config)
responses = tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertEqual(responses[0], EXPECTED_OUTPUT)
def test_max_length_if_input_embeds(self): def test_max_length_if_input_embeds(self):
article = "Today a dragon flew over Paris." article = "Today a dragon flew over Paris."
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)

View File

@@ -840,7 +840,7 @@ class CacheExportIntegrationTest(unittest.TestCase):
input_ids = torch.zeros((1, 3), dtype=torch.long) input_ids = torch.zeros((1, 3), dtype=torch.long)
cache_position = torch.tensor([0, 1, 2], dtype=torch.long) cache_position = torch.tensor([0, 1, 2], dtype=torch.long)
dynamic_shapes = {"input_ids": {1: torch.export.Dim.DYNAMIC}, "cache_position": {0: torch.export.Dim.DYNAMIC}} dynamic_shapes = {"input_ids": {1: torch.export.Dim.DYNAMIC}, "cache_position": {0: torch.export.Dim.DYNAMIC}}
strict = version.parse(torch.__version__) != version.parse("2.7.0") strict = version.parse(torch.__version__) < version.parse("2.7.0")
exported_program = exportable_module.export( exported_program = exportable_module.export(
input_ids=input_ids, input_ids=input_ids,
cache_position=cache_position, cache_position=cache_position,