diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index e36417269d..c778e9e012 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -3764,11 +3764,11 @@ class GenerationMixin(ContinuousMixin):
         return gathered_tensor
 
     @staticmethod
-    def _beam_search_has_unfinished_sequences(
+    def _check_early_stop_heuristic(
+        is_early_stop_heuristic_unsatisfied: torch.Tensor,
         running_beam_scores: torch.Tensor,
         beam_scores: torch.Tensor,
         is_sent_finished: torch.Tensor,
-        next_token_hits_stopping_criteria: torch.Tensor,
         cur_len: int,
         max_length: int,
         decoder_prompt_len: int,
@@ -3776,34 +3776,57 @@ class GenerationMixin(ContinuousMixin):
         length_penalty: float,
     ):
         """
-        Beam Search stopping condition -- halts the generation loop if any of these conditions becomes False
+        Determine whether early stopping is possible by checking if the best possible score of running beams
+        could still improve upon the finished ones.
+
+        Mechanism:
+        - Without a length penalty, beam scores typically decrease as more tokens are generated.
+        So, if the *best possible* score from any running beam is already worse than the *worst* finished beam,
+        we can safely stop early.
+        - With a length penalty, scores may increase with longer sequences. In this case, we use heuristics
+        to estimate the best possible score — though this estimate may not always be correct — and stop
+        if no further improvement seems likely.
+
+        We apply different heuristics depending on the value of `early_stopping`:
+        1. `early_stopping == False`:
+        -> Use a heuristic that assumes the best score comes from the current length minus the decoder prompt length.
+        -> See detailed discussion: https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
+
+        2. `early_stopping == "never"`:
+        -> Estimate the best score using either `max_length` or `cur_len`, depending on the sign of `length_penalty`.
+        -> A positive length penalty favors longer sequences, so we use `max_length` in that case.
+
+        NOTE: the canonical beam search implementation can be replicated with `early_stopping="never"` and
+        `length_penalty=0.0`, which are NOT the default flags. The default behavior was empirically found to produce
+        better sequences (prior to 2022), and changing it is BC breaking.
+
+        Returns:
+            `torch.Tensor`: boolean tensor of shape `(batch_size, 1)`. An entry is `True` while the heuristic still
+            allows a running beam of that batch item to improve on its finished beams; once `False` it stays
+            `False`, because the new verdict is AND-ed with the incoming `is_early_stop_heuristic_unsatisfied`.
         """
-        # a. Can the open beams improve the top completed scores?
-        # early_stopping == False -> apply heuristic = always get the best score from `cur_len - decoder_prompt_len`.
-        # early_stopping == "never" -> compute the best score from `max_length` or `cur_len`, depending on the
-        # sign of `length_penalty`. Positive `length_penalty` favors longer sequences, thus we use
-        # `max_length` there.
-        # !!
-        # Be sure to check the docstring for `early_stopping` and `length_penalty`. The default parameterization
-        # does NOT correspond to a canonical beam search implementation, and tends to favor shorter output sequences
-        # compared to it (the heuristic active by default underestimates the maximum achievable score, and thus cut
-        # generation short). Also, be mindful that length penalty > 0.0 actually favors longer sequences, despite
-        # its name. These modifications were empirically found in the past (prior to 2022) to produce better quality
-        # generations, and changing them is BC breaking.
-        # For a canonical beam search implementation, set `early_stopping="never"` and `length_penalty=0.0`.
-        # See the discussion below for more details.
-        # https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
-        # !!
         if early_stopping == "never" and length_penalty > 0.0:
             best_hypothetical_length = max_length - decoder_prompt_len
         else:
             best_hypothetical_length = cur_len - decoder_prompt_len
-
-        # best-case scenario: the next tokens have logprobs=0 (probability=1), and the score stays the same before
-        # applying length penalty
         best_possible_running_score = running_beam_scores[:, :1] / (best_hypothetical_length**length_penalty)
         worst_finished_score = torch.where(is_sent_finished, torch.min(beam_scores, dim=1, keepdim=True)[0], -1.0e9)
-        improvement_possible = torch.any(best_possible_running_score > worst_finished_score)
+        return is_early_stop_heuristic_unsatisfied & torch.any(
+            best_possible_running_score > worst_finished_score, dim=-1, keepdim=True
+        )
+
+    @staticmethod
+    def _beam_search_has_unfinished_sequences(
+        is_early_stop_heuristic_unsatisfied: torch.Tensor,
+        is_sent_finished: torch.Tensor,
+        next_token_hits_stopping_criteria: torch.Tensor,
+        early_stopping: Union[bool, str],
+    ):
+        """
+        Beam Search stopping condition -- halts the generation loop if any of these conditions becomes False
+        """
+        # a. Can the open beams improve the top completed scores?
+        improvement_possible = torch.any(is_early_stop_heuristic_unsatisfied)
 
         # b. Is there still a beam without fully completed sequences? This is only relevant if early_stopping is
         # enabled, where we want to finish as soon as all beams have a completed sequence.
@@ -3899,6 +3922,7 @@ class GenerationMixin(ContinuousMixin):
         topk_log_probs: torch.Tensor,
         beam_indices: torch.Tensor,
         topk_running_beam_indices: torch.Tensor,
+        is_early_stop_heuristic_unsatisfied: torch.Tensor,
         is_sent_finished: torch.Tensor,
         next_token_hits_stopping_criteria: torch.Tensor,
         top_num_beam_mask: torch.Tensor,
@@ -3923,6 +3947,9 @@ class GenerationMixin(ContinuousMixin):
         # - make sure no scores can be added anymore if beam is full and early stopping is on
         beams_in_batch_are_full = torch.all(is_sent_finished, axis=-1, keepdims=True) & (early_stopping is True)
         topk_log_probs += beams_in_batch_are_full.to(torch.float32) * -1.0e9
+        # - make sure no scores can be added anymore if improvement is not possible
+        topk_log_probs += (~is_early_stop_heuristic_unsatisfied).to(torch.float32) * -1.0e9
+
         # - make sure still running sequences cannot be chosen as finalized beam
         topk_log_probs += (~did_top_num_beams_just_finished) * -1.0e9
 
@@ -4074,6 +4101,9 @@ class GenerationMixin(ContinuousMixin):
         # per batch, beam-item state bit indicating if sentence has finished.
         is_sent_finished = torch.zeros((batch_size, num_beams), dtype=torch.bool, device=input_ids.device)
 
+        # per batch state bit indicating if there is a possibility to improve the best finished sentence.
+        is_early_stop_heuristic_unsatisfied = torch.ones((batch_size, 1), dtype=torch.bool, device=input_ids.device)
+
         # per batch, beam-item state bit indicating if there are valid continuations.
         next_token_hits_stopping_criteria = torch.zeros(
             (batch_size, num_beams), dtype=torch.bool, device=input_ids.device
@@ -4186,6 +4216,7 @@ class GenerationMixin(ContinuousMixin):
                 topk_log_probs=topk_log_probs,
                 beam_indices=beam_indices,
                 topk_running_beam_indices=topk_running_beam_indices,
+                is_early_stop_heuristic_unsatisfied=is_early_stop_heuristic_unsatisfied,
                 is_sent_finished=is_sent_finished,
                 next_token_hits_stopping_criteria=next_token_hits_stopping_criteria,
                 top_num_beam_mask=top_num_beam_mask,
@@ -4207,16 +4238,22 @@ class GenerationMixin(ContinuousMixin):
             )
 
             cur_len = cur_len + 1
+            is_early_stop_heuristic_unsatisfied = self._check_early_stop_heuristic(
+                is_early_stop_heuristic_unsatisfied=is_early_stop_heuristic_unsatisfied,
+                running_beam_scores=running_beam_scores,
+                beam_scores=beam_scores,
+                is_sent_finished=is_sent_finished,
+                cur_len=cur_len,
+                max_length=max_length,
+                decoder_prompt_len=decoder_prompt_len,
+                early_stopping=early_stopping,
+                length_penalty=length_penalty,
+            )
             this_peer_finished = not self._beam_search_has_unfinished_sequences(
-                running_beam_scores,
-                beam_scores,
+                is_early_stop_heuristic_unsatisfied,
                 is_sent_finished,
                 next_token_hits_stopping_criteria,
-                cur_len,
-                max_length,
-                decoder_prompt_len,
                 early_stopping,
-                length_penalty,
             )
 
         # 5. prepare outputs
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index c6fd59f68b..9edb1fe99f 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -2887,6 +2887,53 @@ class GenerationIntegrationTests(unittest.TestCase):
             ],
         )
 
+    @slow
+    def test_beam_search_early_stop_heuristic(self):
+        """Regression test for #38778 (early stopping needs to be tracked at a batch level)"""
+        EXPECTED_OUTPUT = (
+            "<|user|>\nWhat is 3+5?\n<|assistant|>\nThe sum of 3 and 5 is 8. \n\nSo, 3 + 5 = 8. \n\n"
+            "Let's confirm this using Python code:\n\n```python\n# Define the numbers\nnum1 = 3\nnum2 = 5\n\n"
+            "# Calculate the sum\nresult = num1 + num2\n\n# Print the result\nprint(result)\n```\n"
+            "```output\n8\n```\nThe sum of 3 and 5 is \\(\\boxed{8}\\)."
+        )
+
+        model = AutoModelForCausalLM.from_pretrained("allenai/OLMo-2-0425-1B-Instruct").to(torch_device)
+        tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-2-0425-1B-Instruct", padding_side="left")
+        generation_config = GenerationConfig(
+            num_beams=10,
+            max_new_tokens=256,
+            length_penalty=2,
+        )
+        # batch of 1
+        question = [{"role": "user", "content": "What is 3+5?"}]
+        question = tokenizer.apply_chat_template(
+            question, tokenize=False, add_generation_prompt=True, return_tensors="pt"
+        )
+        inputs = tokenizer(question, return_tensors="pt", padding=True).to(torch_device)
+        outputs = model.generate(**inputs, generation_config=generation_config)
+        responses = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        self.assertEqual(responses[0], EXPECTED_OUTPUT)
+
+        # batch of 2
+        question = [{"role": "user", "content": "What is 3+5?"}]
+        cot_question = [
+            {
+                "role": "user",
+                "content": "What is 3+5? Explain your reasoning step by step, and provide the final answer at the end.",
+            }
+        ]
+        question = tokenizer.apply_chat_template(
+            question, tokenize=False, add_generation_prompt=True, return_tensors="pt"
+        )
+        cot_question = tokenizer.apply_chat_template(
+            cot_question, tokenize=False, add_generation_prompt=True, return_tensors="pt"
+        )
+        inputs = tokenizer([question, cot_question], return_tensors="pt", padding=True).to(torch_device)
+
+        outputs = model.generate(**inputs, generation_config=generation_config)
+        responses = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        self.assertEqual(responses[0], EXPECTED_OUTPUT)
+
     def test_max_length_if_input_embeds(self):
         article = "Today a dragon flew over Paris."
         model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py
index 4c6b352168..0c3c90768d 100644
--- a/tests/utils/test_cache_utils.py
+++ b/tests/utils/test_cache_utils.py
@@ -840,7 +840,7 @@ class CacheExportIntegrationTest(unittest.TestCase):
         input_ids = torch.zeros((1, 3), dtype=torch.long)
         cache_position = torch.tensor([0, 1, 2], dtype=torch.long)
         dynamic_shapes = {"input_ids": {1: torch.export.Dim.DYNAMIC}, "cache_position": {0: torch.export.Dim.DYNAMIC}}
-        strict = version.parse(torch.__version__) != version.parse("2.7.0")
+        strict = version.parse(torch.__version__) < version.parse("2.7.0")
         exported_program = exportable_module.export(
             input_ids=input_ids,
             cache_position=cache_position,