From 453079c7f843e7eb920a47fcdaa431413ac0fe72 Mon Sep 17 00:00:00 2001 From: Xin Qiu Date: Wed, 15 Nov 2023 20:49:14 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=A8=F0=9F=9A=A8=20Fix=20beam=20score?= =?UTF-8?q?=20calculation=20issue=20for=20decoder-only=20models=20(#27351)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix beam score calculation issue for decoder-only models * Update beam search test and fix code quality issue * Fix beam_sample, group_beam_search and constrained_beam_search * Split test for pytorch and TF, add documentation --------- Co-authored-by: Xin Qiu --- src/transformers/generation/beam_search.py | 42 +++++++++++++++++---- src/transformers/generation/utils.py | 16 ++++++++ tests/generation/test_framework_agnostic.py | 6 ++- 3 files changed, 55 insertions(+), 9 deletions(-) diff --git a/src/transformers/generation/beam_search.py b/src/transformers/generation/beam_search.py index 03334b6b61..a46859c88c 100644 --- a/src/transformers/generation/beam_search.py +++ b/src/transformers/generation/beam_search.py @@ -222,8 +222,10 @@ class BeamSearchScorer(BeamScorer): eos_token_id: Optional[Union[int, List[int]]] = None, beam_indices: Optional[torch.LongTensor] = None, group_index: Optional[int] = 0, + decoder_prompt_len: Optional[int] = 0, ) -> Dict[str, torch.Tensor]: - cur_len = input_ids.shape[-1] + 1 # add up to the length which the next_scores is calculated on + # add up to the length which the next_scores is calculated on + cur_len = input_ids.shape[-1] - decoder_prompt_len + 1 batch_size = len(self._beam_hyps) // self.num_beam_groups if not (batch_size == (input_ids.shape[0] // self.group_size)): @@ -277,10 +279,15 @@ class BeamSearchScorer(BeamScorer): else: beam_index = None + # skip the corner case where the very first generated token is eos_token + if decoder_prompt_len == input_ids.shape[-1]: + continue + self._beam_hyps[batch_group_idx].add( input_ids[batch_beam_idx].clone(), next_score.item(), beam_indices=beam_index, + decoder_prompt_len=decoder_prompt_len, ) else: # add next predicted token since it is not eos_token @@ -322,6 +329,7 @@ class BeamSearchScorer(BeamScorer): pad_token_id: Optional[int] = None, eos_token_id: Optional[Union[int, List[int]]] = None, beam_indices: Optional[torch.LongTensor] = None, + decoder_prompt_len: Optional[int] = 0, ) -> Tuple[torch.LongTensor]: batch_size = len(self._beam_hyps) // self.num_beam_groups @@ -340,7 +348,7 @@ class BeamSearchScorer(BeamScorer): final_score = final_beam_scores[batch_beam_idx].item() final_tokens = input_ids[batch_beam_idx] beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None - beam_hyp.add(final_tokens, final_score, beam_indices=beam_index) + beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, decoder_prompt_len=decoder_prompt_len) # select the best hypotheses sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep) @@ -511,6 +519,7 @@ class ConstrainedBeamSearchScorer(BeamScorer): pad_token_id: Optional[int] = None, eos_token_id: Optional[Union[int, List[int]]] = None, beam_indices: Optional[torch.LongTensor] = None, + decoder_prompt_len: Optional[int] = 0, ) -> Tuple[torch.Tensor]: r""" Args: @@ -535,7 +544,8 @@ class ConstrainedBeamSearchScorer(BeamScorer): The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. beam_indices (`torch.LongTensor`, *optional*): Beam indices indicating to which beam hypothesis each token correspond. - + decoder_prompt_len (`int`, *optional*): + The length of prompt that is included in the input to decoder. Return: `UserDict`: A dictionary composed of the fields as defined above: @@ -550,7 +560,8 @@ class ConstrainedBeamSearchScorer(BeamScorer): indicating to which beam the next tokens shall be added. """ - cur_len = input_ids.shape[-1] + 1 # add up to the length which the next_scores is calculated on + # add up to the length which the next_scores is calculated on + cur_len = input_ids.shape[-1] - decoder_prompt_len + 1 batch_size = len(self._beam_hyps) if not (batch_size == (input_ids.shape[0] // self.group_size)): if self.num_beam_groups > 1: @@ -606,10 +617,16 @@ class ConstrainedBeamSearchScorer(BeamScorer): else: beam_index = None + # skip the corner case where the only constraint token is + # eos_token and the very first generated token is eos_token + if decoder_prompt_len == input_ids.shape[-1]: + continue + beam_hyp.add( input_ids[batch_beam_idx].clone(), next_score.item(), beam_indices=beam_index, + decoder_prompt_len=decoder_prompt_len, ) else: # add next predicted token since it is not eos_token @@ -805,6 +822,7 @@ class ConstrainedBeamSearchScorer(BeamScorer): pad_token_id: Optional[int] = None, eos_token_id: Optional[Union[int, List[int]]] = None, beam_indices: Optional[torch.LongTensor] = None, + decoder_prompt_len: Optional[int] = 0, ) -> Tuple[torch.LongTensor]: batch_size = len(self._beam_hyps) @@ -828,7 +846,9 @@ class ConstrainedBeamSearchScorer(BeamScorer): completes_constraint = self.check_completes_constraints(final_tokens.cpu().tolist()) if completes_constraint: beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None - beam_hyp.add(final_tokens, final_score, beam_indices=beam_index) + beam_hyp.add( + final_tokens, final_score, beam_indices=beam_index, decoder_prompt_len=decoder_prompt_len + ) ids_collect.append(beam_id) # due to overly complex constraints or other factors, sometimes we can't gaurantee a successful @@ -839,7 +859,7 @@ class ConstrainedBeamSearchScorer(BeamScorer): batch_beam_idx = batch_idx * self.num_beams + beam_id final_score = final_beam_scores[batch_beam_idx].item() final_tokens = input_ids[batch_beam_idx] - beam_hyp.add(final_tokens, final_score) + beam_hyp.add(final_tokens, final_score, decoder_prompt_len=decoder_prompt_len) if len(ids_collect) >= self.num_beam_hyps_to_keep: break @@ -931,11 +951,17 @@ class BeamHypotheses: """ return len(self.beams) - def add(self, hyp: torch.LongTensor, sum_logprobs: float, beam_indices: Optional[torch.LongTensor] = None): + def add( + self, + hyp: torch.LongTensor, + sum_logprobs: float, + beam_indices: Optional[torch.LongTensor] = None, + decoder_prompt_len: Optional[int] = 0, + ): """ Add a new hypothesis to the list. """ - score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty) + score = sum_logprobs / ((hyp.shape[-1] - decoder_prompt_len) ** self.length_penalty) if len(self) < self.num_beams or score > self.worst_score: self.beams.append((score, hyp, beam_indices)) if len(self) > self.num_beams: diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 4dbfc36706..10ffffc37c 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -3172,6 +3172,8 @@ class GenerationMixin: beam_scores = beam_scores.view((batch_size * num_beams,)) this_peer_finished = False # used by synced_gpus only + + decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder while True: if synced_gpus: # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. @@ -3246,6 +3248,7 @@ class GenerationMixin: pad_token_id=pad_token_id, eos_token_id=eos_token_id, beam_indices=beam_indices, + decoder_prompt_len=decoder_prompt_len, ) beam_scores = beam_outputs["next_beam_scores"] @@ -3281,6 +3284,7 @@ class GenerationMixin: eos_token_id=eos_token_id, max_length=stopping_criteria.max_length, beam_indices=beam_indices, + decoder_prompt_len=decoder_prompt_len, ) if return_dict_in_generate: @@ -3500,6 +3504,8 @@ class GenerationMixin: beam_scores = beam_scores.view((batch_size * num_beams,)) this_peer_finished = False # used by synced_gpus only + + decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder while True: if synced_gpus: # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. @@ -3578,6 +3584,7 @@ class GenerationMixin: pad_token_id=pad_token_id, eos_token_id=eos_token_id, beam_indices=beam_indices, + decoder_prompt_len=decoder_prompt_len, ) beam_scores = beam_outputs["next_beam_scores"] beam_next_tokens = beam_outputs["next_beam_tokens"] @@ -3612,6 +3619,7 @@ class GenerationMixin: eos_token_id=eos_token_id, max_length=stopping_criteria.max_length, beam_indices=beam_indices, + decoder_prompt_len=decoder_prompt_len, ) if return_dict_in_generate: @@ -3837,6 +3845,8 @@ class GenerationMixin: beam_scores = beam_scores.view((batch_size * num_beams,)) this_peer_finished = False # used by synced_gpus only + + decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder while True: if synced_gpus: # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. @@ -3924,6 +3934,7 @@ class GenerationMixin: eos_token_id=eos_token_id, beam_indices=process_beam_indices, group_index=beam_group_idx, + decoder_prompt_len=decoder_prompt_len, ) beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"] beam_next_tokens = beam_outputs["next_beam_tokens"] @@ -3993,6 +4004,7 @@ class GenerationMixin: eos_token_id=eos_token_id, max_length=stopping_criteria.max_length, beam_indices=final_beam_indices, + decoder_prompt_len=decoder_prompt_len, ) if return_dict_in_generate: @@ -4220,6 +4232,8 @@ class GenerationMixin: beam_scores = beam_scores.view((batch_size * num_beams,)) this_peer_finished = False # used by synced_gpus only + + decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder while True: if synced_gpus: # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. @@ -4298,6 +4312,7 @@ class GenerationMixin: pad_token_id=pad_token_id, eos_token_id=eos_token_id, beam_indices=beam_indices, + decoder_prompt_len=decoder_prompt_len, ) beam_scores = beam_outputs["next_beam_scores"] beam_next_tokens = beam_outputs["next_beam_tokens"] @@ -4331,6 +4346,7 @@ class GenerationMixin: eos_token_id=eos_token_id, max_length=stopping_criteria.max_length, beam_indices=beam_indices, + decoder_prompt_len=decoder_prompt_len, ) if return_dict_in_generate: diff --git a/tests/generation/test_framework_agnostic.py b/tests/generation/test_framework_agnostic.py index 306cb15168..8a26980164 100644 --- a/tests/generation/test_framework_agnostic.py +++ b/tests/generation/test_framework_agnostic.py @@ -633,7 +633,11 @@ class GenerationIntegrationTestsMixin: "do_sample": False, "num_beams": 3, } - expectation = 13 + if is_pt: + expectation = 20 + else: + # TODO (joao): fix me + expectation = 13 tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") text = """Hello, my dog is cute and"""