fix(generation): stop beam search per-instance when heuristic satisfied (#38778)

* fix(decoding): stop beam search per-instance when heuristic satisfied

Previously, when `early_stopping` was set to `False`, the early-stopping heuristic only halted generation when **all** batch instances reached the criterion. As a result, instances that the heuristic had already deemed impossible to improve kept generating, leading to inconsistent and overlong outputs across the batch.

Now we apply the heuristic **per-instance**: once all beams of a given batch instance are impossible to improve, we mark that instance as finished while letting the others continue. This restores the expected behavior and ensures consistency in batched generation.

* Add test case GenerationIntegrationTests.test_beam_search_early_stop_heuristic

* Update naming improvement_possibility -> is_early_stop_heuristic_unsatisfied

* Add comments for early stop heuristic

* Update src/transformers/generation/utils.py

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
Guang Yang
2025-07-08 03:59:37 -05:00
committed by GitHub
parent 0b0ede8b2b
commit 356fd68109
3 changed files with 109 additions and 30 deletions

View File

@@ -2887,6 +2887,53 @@ class GenerationIntegrationTests(unittest.TestCase):
],
)
@slow
def test_beam_search_early_stop_heuristic(self):
    """Regression test for #38778 (early stopping needs to be tracked at a batch level).

    Generates with beam search for the same short question twice: once alone
    (batch of 1) and once batched together with a longer chain-of-thought
    prompt (batch of 2). In both runs the decoded answer to the short question
    must equal ``EXPECTED_OUTPUT``, i.e. it must not depend on what else is in
    the batch.
    """
    EXPECTED_OUTPUT = (
        "<|user|>\nWhat is 3+5?\n<|assistant|>\nThe sum of 3 and 5 is 8. \n\nSo, 3 + 5 = 8. \n\n"
        "Let's confirm this using Python code:\n\n```python\n# Define the numbers\nnum1 = 3\nnum2 = 5\n\n"
        "# Calculate the sum\nresult = num1 + num2\n\n# Print the result\nprint(result)\n```\n"
        "```output\n8\n```\nThe sum of 3 and 5 is \\(\\boxed{8}\\)."
    )

    checkpoint = "allenai/OLMo-2-0425-1B-Instruct"
    model = AutoModelForCausalLM.from_pretrained(checkpoint).to(torch_device)
    # Left padding keeps the prompts right-aligned in the batch of 2.
    tokenizer = AutoTokenizer.from_pretrained(checkpoint, padding_side="left")
    # NOTE(review): `early_stopping` is intentionally left unset; presumably the default
    # selects the heuristic-based stopping path this test targets -- confirm.
    generation_config = GenerationConfig(
        num_beams=10,
        max_new_tokens=256,
        length_penalty=2,
    )

    # With `tokenize=False` the chat template only renders a prompt string, so no
    # `return_tensors` argument is passed here.
    question = tokenizer.apply_chat_template(
        [{"role": "user", "content": "What is 3+5?"}],
        tokenize=False,
        add_generation_prompt=True,
    )
    cot_question = tokenizer.apply_chat_template(
        [
            {
                "role": "user",
                "content": "What is 3+5? Explain your reasoning step by step, and provide the final answer at the end.",
            }
        ],
        tokenize=False,
        add_generation_prompt=True,
    )

    # batch of 1
    # Inputs go to `torch_device` (same device as the model) rather than a hard-coded
    # "cuda", so the test also runs on other accelerators.
    inputs = tokenizer(question, return_tensors="pt", padding=True).to(torch_device)
    outputs = model.generate(**inputs, generation_config=generation_config)
    responses = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    self.assertEqual(responses[0], EXPECTED_OUTPUT)

    # batch of 2
    inputs = tokenizer([question, cot_question], return_tensors="pt", padding=True).to(torch_device)
    outputs = model.generate(**inputs, generation_config=generation_config)
    responses = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    # Only the short question is compared: its output must match the batch-of-1 result.
    self.assertEqual(responses[0], EXPECTED_OUTPUT)
def test_max_length_if_input_embeds(self):
article = "Today a dragon flew over Paris."
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)

View File

@@ -840,7 +840,7 @@ class CacheExportIntegrationTest(unittest.TestCase):
input_ids = torch.zeros((1, 3), dtype=torch.long)
cache_position = torch.tensor([0, 1, 2], dtype=torch.long)
dynamic_shapes = {"input_ids": {1: torch.export.Dim.DYNAMIC}, "cache_position": {0: torch.export.Dim.DYNAMIC}}
strict = version.parse(torch.__version__) != version.parse("2.7.0")
strict = version.parse(torch.__version__) < version.parse("2.7.0")
exported_program = exportable_module.export(
input_ids=input_ids,
cache_position=cache_position,