fix(generation): stop beam search per-instance when heuristic satisfied (#38778)
* fix(decoding): stop beam search per-instance when heuristic satisfied Previously, when early_stopping is set to `False`, the early-stopping heuristic only halted generation when **all** batch instances reached the criterion. This caused instances that are impossible (suggested by the heuristic) to improve keep generating, leading to inconsistent and overlong outputs across the batch. Now we apply the heuristic **per-instance**: once a certain instance of batch has its all beams impossibe to improve, we mark that instance finished while letting others continue. This restores expected behavior and ensures consistency in batched generation. * Add test case GenerationIntegrationTests.test_beam_search_early_stop_heuristic * Update naming improvement_possibility -> is_early_stop_heuristic_unsatisfied * Add comments for early stop heuristic * Update src/transformers/generation/utils.py --------- Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
@@ -3764,11 +3764,11 @@ class GenerationMixin(ContinuousMixin):
|
||||
return gathered_tensor
|
||||
|
||||
@staticmethod
|
||||
def _beam_search_has_unfinished_sequences(
|
||||
def _check_early_stop_heuristic(
|
||||
is_early_stop_heuristic_unsatisfied: torch.Tensor,
|
||||
running_beam_scores: torch.Tensor,
|
||||
beam_scores: torch.Tensor,
|
||||
is_sent_finished: torch.Tensor,
|
||||
next_token_hits_stopping_criteria: torch.Tensor,
|
||||
cur_len: int,
|
||||
max_length: int,
|
||||
decoder_prompt_len: int,
|
||||
@@ -3776,34 +3776,52 @@ class GenerationMixin(ContinuousMixin):
|
||||
length_penalty: float,
|
||||
):
|
||||
"""
|
||||
Beam Search stopping condition -- halts the generation loop if any of these conditions becomes False
|
||||
Determine whether early stopping is possible by checking if the best possible score of running beams
|
||||
could still improve upon the finished ones.
|
||||
|
||||
Mechanism:
|
||||
- Without a length penalty, beam scores typically decrease as more tokens are generated.
|
||||
So, if the *best possible* score from any running beam is already worse than the *worst* finished beam,
|
||||
we can safely stop early.
|
||||
- With a length penalty, scores may increase with longer sequences. In this case, we use heuristics
|
||||
to estimate the best possible score — though this estimate may not always be correct — and stop
|
||||
if no further improvement seems likely.
|
||||
|
||||
We apply different heuristics depending on the value of `early_stopping`:
|
||||
1. `early_stopping == False`:
|
||||
-> Use a heuristic that assumes the best score comes from the current length minus the decoder prompt length.
|
||||
-> See detailed discussion: https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
|
||||
|
||||
2. `early_stopping == "never"`:
|
||||
-> Estimate the best score using either `max_length` or `cur_len`, depending on the sign of `length_penalty`.
|
||||
-> A positive length penalty favors longer sequences, so we use `max_length` in that case.
|
||||
|
||||
NOTE: the canonical beam search implementation can be replicated with `early_stopping="never"` and
|
||||
`length_penalty=0.0`, which are NOT the default flags. The default behavior was empirically found to produce
|
||||
better sequences (prior to 2022), and changing it is BC breaking.
|
||||
"""
|
||||
# a. Can the open beams improve the top completed scores?
|
||||
# early_stopping == False -> apply heuristic = always get the best score from `cur_len - decoder_prompt_len`.
|
||||
# early_stopping == "never" -> compute the best score from `max_length` or `cur_len`, depending on the
|
||||
# sign of `length_penalty`. Positive `length_penalty` favors longer sequences, thus we use
|
||||
# `max_length` there.
|
||||
# !!
|
||||
# Be sure to check the docstring for `early_stopping` and `length_penalty`. The default parameterization
|
||||
# does NOT correspond to a canonical beam search implementation, and tends to favor shorter output sequences
|
||||
# compared to it (the heuristic active by default underestimates the maximum achievable score, and thus cut
|
||||
# generation short). Also, be mindful that length penalty > 0.0 actually favors longer sequences, despite
|
||||
# its name. These modifications were empirically found in the past (prior to 2022) to produce better quality
|
||||
# generations, and changing them is BC breaking.
|
||||
# For a canonical beam search implementation, set `early_stopping="never"` and `length_penalty=0.0`.
|
||||
# See the discussion below for more details.
|
||||
# https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
|
||||
# !!
|
||||
if early_stopping == "never" and length_penalty > 0.0:
|
||||
best_hypothetical_length = max_length - decoder_prompt_len
|
||||
else:
|
||||
best_hypothetical_length = cur_len - decoder_prompt_len
|
||||
|
||||
# best-case scenario: the next tokens have logprobs=0 (probability=1), and the score stays the same before
|
||||
# applying length penalty
|
||||
best_possible_running_score = running_beam_scores[:, :1] / (best_hypothetical_length**length_penalty)
|
||||
worst_finished_score = torch.where(is_sent_finished, torch.min(beam_scores, dim=1, keepdim=True)[0], -1.0e9)
|
||||
improvement_possible = torch.any(best_possible_running_score > worst_finished_score)
|
||||
return is_early_stop_heuristic_unsatisfied & torch.any(
|
||||
best_possible_running_score > worst_finished_score, dim=-1, keepdim=True
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _beam_search_has_unfinished_sequences(
|
||||
is_early_stop_heuristic_unsatisfied: torch.Tensor,
|
||||
is_sent_finished: torch.Tensor,
|
||||
next_token_hits_stopping_criteria: torch.Tensor,
|
||||
early_stopping: Union[bool, str],
|
||||
):
|
||||
"""
|
||||
Beam Search stopping condition -- halts the generation loop if any of these conditions becomes False
|
||||
"""
|
||||
# a. Can the open beams improve the top completed scores?
|
||||
improvement_possible = torch.any(is_early_stop_heuristic_unsatisfied)
|
||||
|
||||
# b. Is there still a beam without fully completed sequences? This is only relevant if early_stopping is
|
||||
# enabled, where we want to finish as soon as all beams have a completed sequence.
|
||||
@@ -3899,6 +3917,7 @@ class GenerationMixin(ContinuousMixin):
|
||||
topk_log_probs: torch.Tensor,
|
||||
beam_indices: torch.Tensor,
|
||||
topk_running_beam_indices: torch.Tensor,
|
||||
is_early_stop_heuristic_unsatisfied: torch.Tensor,
|
||||
is_sent_finished: torch.Tensor,
|
||||
next_token_hits_stopping_criteria: torch.Tensor,
|
||||
top_num_beam_mask: torch.Tensor,
|
||||
@@ -3923,6 +3942,9 @@ class GenerationMixin(ContinuousMixin):
|
||||
# - make sure no scores can be added anymore if beam is full and early stopping is on
|
||||
beams_in_batch_are_full = torch.all(is_sent_finished, axis=-1, keepdims=True) & (early_stopping is True)
|
||||
topk_log_probs += beams_in_batch_are_full.to(torch.float32) * -1.0e9
|
||||
# - make sure no scores can be added anymore if improvement is not possible
|
||||
topk_log_probs += (~is_early_stop_heuristic_unsatisfied).to(torch.float32) * -1.0e9
|
||||
|
||||
# - make sure still running sequences cannot be chosen as finalized beam
|
||||
topk_log_probs += (~did_top_num_beams_just_finished) * -1.0e9
|
||||
|
||||
@@ -4074,6 +4096,9 @@ class GenerationMixin(ContinuousMixin):
|
||||
# per batch, beam-item state bit indicating if sentence has finished.
|
||||
is_sent_finished = torch.zeros((batch_size, num_beams), dtype=torch.bool, device=input_ids.device)
|
||||
|
||||
# per batch state bit indicating if there is a possibility to improve the best finished sentence.
|
||||
is_early_stop_heuristic_unsatisfied = torch.ones((batch_size, 1), dtype=torch.bool, device=input_ids.device)
|
||||
|
||||
# per batch, beam-item state bit indicating if there are valid continuations.
|
||||
next_token_hits_stopping_criteria = torch.zeros(
|
||||
(batch_size, num_beams), dtype=torch.bool, device=input_ids.device
|
||||
@@ -4186,6 +4211,7 @@ class GenerationMixin(ContinuousMixin):
|
||||
topk_log_probs=topk_log_probs,
|
||||
beam_indices=beam_indices,
|
||||
topk_running_beam_indices=topk_running_beam_indices,
|
||||
is_early_stop_heuristic_unsatisfied=is_early_stop_heuristic_unsatisfied,
|
||||
is_sent_finished=is_sent_finished,
|
||||
next_token_hits_stopping_criteria=next_token_hits_stopping_criteria,
|
||||
top_num_beam_mask=top_num_beam_mask,
|
||||
@@ -4207,16 +4233,22 @@ class GenerationMixin(ContinuousMixin):
|
||||
)
|
||||
|
||||
cur_len = cur_len + 1
|
||||
is_early_stop_heuristic_unsatisfied = self._check_early_stop_heuristic(
|
||||
is_early_stop_heuristic_unsatisfied=is_early_stop_heuristic_unsatisfied,
|
||||
running_beam_scores=running_beam_scores,
|
||||
beam_scores=beam_scores,
|
||||
is_sent_finished=is_sent_finished,
|
||||
cur_len=cur_len,
|
||||
max_length=max_length,
|
||||
decoder_prompt_len=decoder_prompt_len,
|
||||
early_stopping=early_stopping,
|
||||
length_penalty=length_penalty,
|
||||
)
|
||||
this_peer_finished = not self._beam_search_has_unfinished_sequences(
|
||||
running_beam_scores,
|
||||
beam_scores,
|
||||
is_early_stop_heuristic_unsatisfied,
|
||||
is_sent_finished,
|
||||
next_token_hits_stopping_criteria,
|
||||
cur_len,
|
||||
max_length,
|
||||
decoder_prompt_len,
|
||||
early_stopping,
|
||||
length_penalty,
|
||||
)
|
||||
|
||||
# 5. prepare outputs
|
||||
|
||||
@@ -2887,6 +2887,53 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
],
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_beam_search_early_stop_heuristic(self):
|
||||
"""Regression test for #38778 (early stopping needs to be tracked at a batch level)"""
|
||||
EXPECTED_OUTPUT = (
|
||||
"<|user|>\nWhat is 3+5?\n<|assistant|>\nThe sum of 3 and 5 is 8. \n\nSo, 3 + 5 = 8. \n\n"
|
||||
"Let's confirm this using Python code:\n\n```python\n# Define the numbers\nnum1 = 3\nnum2 = 5\n\n"
|
||||
"# Calculate the sum\nresult = num1 + num2\n\n# Print the result\nprint(result)\n```\n"
|
||||
"```output\n8\n```\nThe sum of 3 and 5 is \\(\\boxed{8}\\)."
|
||||
)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("allenai/OLMo-2-0425-1B-Instruct").to(torch_device)
|
||||
tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-2-0425-1B-Instruct", padding_side="left")
|
||||
generation_config = GenerationConfig(
|
||||
num_beams=10,
|
||||
max_new_tokens=256,
|
||||
length_penalty=2,
|
||||
)
|
||||
# batch of 1
|
||||
question = [{"role": "user", "content": "What is 3+5?"}]
|
||||
question = tokenizer.apply_chat_template(
|
||||
question, tokenize=False, add_generation_prompt=True, return_tensors="pt"
|
||||
)
|
||||
inputs = tokenizer(question, return_tensors="pt", padding=True).to("cuda")
|
||||
outputs = model.generate(**inputs, generation_config=generation_config)
|
||||
responses = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
self.assertEqual(responses[0], EXPECTED_OUTPUT)
|
||||
|
||||
# batch of 2
|
||||
question = [{"role": "user", "content": "What is 3+5?"}]
|
||||
cot_question = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is 3+5? Explain your reasoning step by step, and provide the final answer at the end.",
|
||||
}
|
||||
]
|
||||
question = tokenizer.apply_chat_template(
|
||||
question, tokenize=False, add_generation_prompt=True, return_tensors="pt"
|
||||
)
|
||||
cot_question = tokenizer.apply_chat_template(
|
||||
cot_question, tokenize=False, add_generation_prompt=True, return_tensors="pt"
|
||||
)
|
||||
inputs = tokenizer([question, cot_question], return_tensors="pt", padding=True).to("cuda")
|
||||
|
||||
outputs = model.generate(**inputs, generation_config=generation_config)
|
||||
responses = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
self.assertEqual(responses[0], EXPECTED_OUTPUT)
|
||||
|
||||
def test_max_length_if_input_embeds(self):
|
||||
article = "Today a dragon flew over Paris."
|
||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
|
||||
|
||||
@@ -840,7 +840,7 @@ class CacheExportIntegrationTest(unittest.TestCase):
|
||||
input_ids = torch.zeros((1, 3), dtype=torch.long)
|
||||
cache_position = torch.tensor([0, 1, 2], dtype=torch.long)
|
||||
dynamic_shapes = {"input_ids": {1: torch.export.Dim.DYNAMIC}, "cache_position": {0: torch.export.Dim.DYNAMIC}}
|
||||
strict = version.parse(torch.__version__) != version.parse("2.7.0")
|
||||
strict = version.parse(torch.__version__) < version.parse("2.7.0")
|
||||
exported_program = exportable_module.export(
|
||||
input_ids=input_ids,
|
||||
cache_position=cache_position,
|
||||
|
||||
Reference in New Issue
Block a user