Fix remaining issues in beam score calculation (#27808)
* Fix issues in add and is_done for BeamHypotheses * make newly added arguments optional for better compatibility * Directly use cur_len as generated_len, add note for retrocompatibility * update test expectation * make cur_len represents the length of the entire sequence including the decoder prompt * remove redundant if/else in testing
This commit is contained in:
@@ -224,8 +224,8 @@ class BeamSearchScorer(BeamScorer):
|
||||
group_index: Optional[int] = 0,
|
||||
decoder_prompt_len: Optional[int] = 0,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
# add up to the length which the next_scores is calculated on
|
||||
cur_len = input_ids.shape[-1] - decoder_prompt_len + 1
|
||||
# add up to the length which the next_scores is calculated on (including decoder prompt)
|
||||
cur_len = input_ids.shape[-1] + 1
|
||||
batch_size = len(self._beam_hyps) // self.num_beam_groups
|
||||
|
||||
if not (batch_size == (input_ids.shape[0] // self.group_size)):
|
||||
@@ -279,15 +279,11 @@ class BeamSearchScorer(BeamScorer):
|
||||
else:
|
||||
beam_index = None
|
||||
|
||||
# skip the corner case where the very first generated token is eos_token
|
||||
if decoder_prompt_len == input_ids.shape[-1]:
|
||||
continue
|
||||
|
||||
self._beam_hyps[batch_group_idx].add(
|
||||
input_ids[batch_beam_idx].clone(),
|
||||
next_score.item(),
|
||||
beam_indices=beam_index,
|
||||
decoder_prompt_len=decoder_prompt_len,
|
||||
generated_len=cur_len - decoder_prompt_len,
|
||||
)
|
||||
else:
|
||||
# add next predicted token since it is not eos_token
|
||||
@@ -308,7 +304,7 @@ class BeamSearchScorer(BeamScorer):
|
||||
|
||||
# Check if we are done so that we can save a pad step if all(done)
|
||||
self._done[batch_group_idx] = self._done[batch_group_idx] or self._beam_hyps[batch_group_idx].is_done(
|
||||
next_scores[batch_idx].max().item(), cur_len
|
||||
next_scores[batch_idx].max().item(), cur_len, decoder_prompt_len
|
||||
)
|
||||
|
||||
return UserDict(
|
||||
@@ -348,7 +344,8 @@ class BeamSearchScorer(BeamScorer):
|
||||
final_score = final_beam_scores[batch_beam_idx].item()
|
||||
final_tokens = input_ids[batch_beam_idx]
|
||||
beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
|
||||
beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, decoder_prompt_len=decoder_prompt_len)
|
||||
generated_len = final_tokens.shape[-1] - decoder_prompt_len
|
||||
beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, generated_len=generated_len)
|
||||
|
||||
# select the best hypotheses
|
||||
sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
|
||||
@@ -560,8 +557,8 @@ class ConstrainedBeamSearchScorer(BeamScorer):
|
||||
indicating to which beam the next tokens shall be added.
|
||||
"""
|
||||
|
||||
# add up to the length which the next_scores is calculated on
|
||||
cur_len = input_ids.shape[-1] - decoder_prompt_len + 1
|
||||
# add up to the length which the next_scores is calculated on (including decoder prompt)
|
||||
cur_len = input_ids.shape[-1] + 1
|
||||
batch_size = len(self._beam_hyps)
|
||||
if not (batch_size == (input_ids.shape[0] // self.group_size)):
|
||||
if self.num_beam_groups > 1:
|
||||
@@ -617,16 +614,11 @@ class ConstrainedBeamSearchScorer(BeamScorer):
|
||||
else:
|
||||
beam_index = None
|
||||
|
||||
# skip the corner case where the only constraint token is
|
||||
# eos_token and the very first generated token is eos_token
|
||||
if decoder_prompt_len == input_ids.shape[-1]:
|
||||
continue
|
||||
|
||||
beam_hyp.add(
|
||||
input_ids[batch_beam_idx].clone(),
|
||||
next_score.item(),
|
||||
beam_indices=beam_index,
|
||||
decoder_prompt_len=decoder_prompt_len,
|
||||
generated_len=cur_len - decoder_prompt_len,
|
||||
)
|
||||
else:
|
||||
# add next predicted token since it is not eos_token
|
||||
@@ -660,7 +652,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
|
||||
|
||||
# Check if we are done so that we can save a pad step if all(done)
|
||||
self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done(
|
||||
next_scores[batch_idx].max().item(), cur_len
|
||||
next_scores[batch_idx].max().item(), cur_len, decoder_prompt_len
|
||||
)
|
||||
|
||||
return UserDict(
|
||||
@@ -846,9 +838,8 @@ class ConstrainedBeamSearchScorer(BeamScorer):
|
||||
completes_constraint = self.check_completes_constraints(final_tokens.cpu().tolist())
|
||||
if completes_constraint:
|
||||
beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
|
||||
beam_hyp.add(
|
||||
final_tokens, final_score, beam_indices=beam_index, decoder_prompt_len=decoder_prompt_len
|
||||
)
|
||||
generated_len = final_tokens.shape[-1] - decoder_prompt_len
|
||||
beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, generated_len=generated_len)
|
||||
ids_collect.append(beam_id)
|
||||
|
||||
# due to overly complex constraints or other factors, sometimes we can't gaurantee a successful
|
||||
@@ -859,7 +850,8 @@ class ConstrainedBeamSearchScorer(BeamScorer):
|
||||
batch_beam_idx = batch_idx * self.num_beams + beam_id
|
||||
final_score = final_beam_scores[batch_beam_idx].item()
|
||||
final_tokens = input_ids[batch_beam_idx]
|
||||
beam_hyp.add(final_tokens, final_score, decoder_prompt_len=decoder_prompt_len)
|
||||
generated_len = final_tokens.shape[-1] - decoder_prompt_len
|
||||
beam_hyp.add(final_tokens, final_score, generated_len=generated_len)
|
||||
if len(ids_collect) >= self.num_beam_hyps_to_keep:
|
||||
break
|
||||
|
||||
@@ -956,12 +948,17 @@ class BeamHypotheses:
|
||||
hyp: torch.LongTensor,
|
||||
sum_logprobs: float,
|
||||
beam_indices: Optional[torch.LongTensor] = None,
|
||||
decoder_prompt_len: Optional[int] = 0,
|
||||
generated_len: Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
Add a new hypothesis to the list.
|
||||
"""
|
||||
score = sum_logprobs / ((hyp.shape[-1] - decoder_prompt_len) ** self.length_penalty)
|
||||
if generated_len is not None:
|
||||
score = sum_logprobs / (generated_len**self.length_penalty)
|
||||
# This 'else' case exists for retrocompatibility
|
||||
else:
|
||||
score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)
|
||||
|
||||
if len(self) < self.num_beams or score > self.worst_score:
|
||||
self.beams.append((score, hyp, beam_indices))
|
||||
if len(self) > self.num_beams:
|
||||
@@ -971,7 +968,7 @@ class BeamHypotheses:
|
||||
else:
|
||||
self.worst_score = min(score, self.worst_score)
|
||||
|
||||
def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool:
|
||||
def is_done(self, best_sum_logprobs: float, cur_len: int, decoder_prompt_len: Optional[int] = 0) -> bool:
|
||||
"""
|
||||
If there are enough hypotheses and that none of the hypotheses being generated can become better than the worst
|
||||
one in the heap, then we are done with this sentence.
|
||||
@@ -987,7 +984,7 @@ class BeamHypotheses:
|
||||
# when `length_penalty` is positive. See the discussion below for more details.
|
||||
# https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
|
||||
elif self.early_stopping is False:
|
||||
highest_attainable_score = best_sum_logprobs / cur_len**self.length_penalty
|
||||
highest_attainable_score = best_sum_logprobs / (cur_len - decoder_prompt_len) ** self.length_penalty
|
||||
ret = self.worst_score >= highest_attainable_score
|
||||
return ret
|
||||
# `"never"`: compute the best possible score, depending on the signal of `length_penalty`
|
||||
@@ -996,9 +993,13 @@ class BeamHypotheses:
|
||||
# abs(`highest_attainable_score`) is obtained -> `highest_attainable_score` is negative, hence we obtain
|
||||
# its max this way
|
||||
if self.length_penalty > 0.0:
|
||||
highest_attainable_score = best_sum_logprobs / self.max_length**self.length_penalty
|
||||
if self.max_length <= decoder_prompt_len:
|
||||
raise ValueError("max_length is not larger than decoder prompt length")
|
||||
highest_attainable_score = (
|
||||
best_sum_logprobs / (self.max_length - decoder_prompt_len) ** self.length_penalty
|
||||
)
|
||||
# the opposite logic applies here (max `highest_attainable_score` from `cur_len`)
|
||||
else:
|
||||
highest_attainable_score = best_sum_logprobs / cur_len**self.length_penalty
|
||||
highest_attainable_score = best_sum_logprobs / (cur_len - decoder_prompt_len) ** self.length_penalty
|
||||
ret = self.worst_score >= highest_attainable_score
|
||||
return ret
|
||||
|
||||
@@ -633,10 +633,6 @@ class GenerationIntegrationTestsMixin:
|
||||
"do_sample": False,
|
||||
"num_beams": 3,
|
||||
}
|
||||
if is_pt:
|
||||
expectation = 20
|
||||
else:
|
||||
# TODO (joao): fix me
|
||||
expectation = 13
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
|
||||
@@ -800,7 +800,7 @@ class ViT2GPT2ModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
preds, scores = generate_step(pixel_values)
|
||||
|
||||
EXPECTED_SCORES = np.array([-0.64145195])
|
||||
EXPECTED_SCORES = np.array([-0.5956343])
|
||||
max_diff = np.amax(np.abs(scores - EXPECTED_SCORES))
|
||||
self.assertLessEqual(max_diff, 1e-4)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user