fix(generation): stop beam search per-instance when heuristic satisfied (#38778)
* fix(decoding): stop beam search per-instance when heuristic satisfied Previously, when early_stopping is set to `False`, the early-stopping heuristic only halted generation when **all** batch instances reached the criterion. This caused instances that are impossible (suggested by the heuristic) to improve keep generating, leading to inconsistent and overlong outputs across the batch. Now we apply the heuristic **per-instance**: once a certain instance of batch has its all beams impossibe to improve, we mark that instance finished while letting others continue. This restores expected behavior and ensures consistency in batched generation. * Add test case GenerationIntegrationTests.test_beam_search_early_stop_heuristic * Update naming improvement_possibility -> is_early_stop_heuristic_unsatisfied * Add comments for early stop heuristic * Update src/transformers/generation/utils.py --------- Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
@@ -2887,6 +2887,53 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
],
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_beam_search_early_stop_heuristic(self):
|
||||
"""Regression test for #38778 (early stopping needs to be tracked at a batch level)"""
|
||||
EXPECTED_OUTPUT = (
|
||||
"<|user|>\nWhat is 3+5?\n<|assistant|>\nThe sum of 3 and 5 is 8. \n\nSo, 3 + 5 = 8. \n\n"
|
||||
"Let's confirm this using Python code:\n\n```python\n# Define the numbers\nnum1 = 3\nnum2 = 5\n\n"
|
||||
"# Calculate the sum\nresult = num1 + num2\n\n# Print the result\nprint(result)\n```\n"
|
||||
"```output\n8\n```\nThe sum of 3 and 5 is \\(\\boxed{8}\\)."
|
||||
)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("allenai/OLMo-2-0425-1B-Instruct").to(torch_device)
|
||||
tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-2-0425-1B-Instruct", padding_side="left")
|
||||
generation_config = GenerationConfig(
|
||||
num_beams=10,
|
||||
max_new_tokens=256,
|
||||
length_penalty=2,
|
||||
)
|
||||
# batch of 1
|
||||
question = [{"role": "user", "content": "What is 3+5?"}]
|
||||
question = tokenizer.apply_chat_template(
|
||||
question, tokenize=False, add_generation_prompt=True, return_tensors="pt"
|
||||
)
|
||||
inputs = tokenizer(question, return_tensors="pt", padding=True).to("cuda")
|
||||
outputs = model.generate(**inputs, generation_config=generation_config)
|
||||
responses = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
self.assertEqual(responses[0], EXPECTED_OUTPUT)
|
||||
|
||||
# batch of 2
|
||||
question = [{"role": "user", "content": "What is 3+5?"}]
|
||||
cot_question = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is 3+5? Explain your reasoning step by step, and provide the final answer at the end.",
|
||||
}
|
||||
]
|
||||
question = tokenizer.apply_chat_template(
|
||||
question, tokenize=False, add_generation_prompt=True, return_tensors="pt"
|
||||
)
|
||||
cot_question = tokenizer.apply_chat_template(
|
||||
cot_question, tokenize=False, add_generation_prompt=True, return_tensors="pt"
|
||||
)
|
||||
inputs = tokenizer([question, cot_question], return_tensors="pt", padding=True).to("cuda")
|
||||
|
||||
outputs = model.generate(**inputs, generation_config=generation_config)
|
||||
responses = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
self.assertEqual(responses[0], EXPECTED_OUTPUT)
|
||||
|
||||
def test_max_length_if_input_embeds(self):
|
||||
article = "Today a dragon flew over Paris."
|
||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
|
||||
|
||||
@@ -840,7 +840,7 @@ class CacheExportIntegrationTest(unittest.TestCase):
|
||||
input_ids = torch.zeros((1, 3), dtype=torch.long)
|
||||
cache_position = torch.tensor([0, 1, 2], dtype=torch.long)
|
||||
dynamic_shapes = {"input_ids": {1: torch.export.Dim.DYNAMIC}, "cache_position": {0: torch.export.Dim.DYNAMIC}}
|
||||
strict = version.parse(torch.__version__) != version.parse("2.7.0")
|
||||
strict = version.parse(torch.__version__) < version.parse("2.7.0")
|
||||
exported_program = exportable_module.export(
|
||||
input_ids=input_ids,
|
||||
cache_position=cache_position,
|
||||
|
||||
Reference in New Issue
Block a user