Fix remaining issues in beam score calculation (#27808)
* Fix issues in add and is_done for BeamHypotheses * make newly added arguments optional for better compatibility * Directly use cur_len as generated_len, add note for retrocompatibility * update test expectation * make cur_len represents the length of the entire sequence including the decoder prompt * remove redundant if/else in testing
This commit is contained in:
@@ -224,8 +224,8 @@ class BeamSearchScorer(BeamScorer):
|
|||||||
group_index: Optional[int] = 0,
|
group_index: Optional[int] = 0,
|
||||||
decoder_prompt_len: Optional[int] = 0,
|
decoder_prompt_len: Optional[int] = 0,
|
||||||
) -> Dict[str, torch.Tensor]:
|
) -> Dict[str, torch.Tensor]:
|
||||||
# add up to the length which the next_scores is calculated on
|
# add up to the length which the next_scores is calculated on (including decoder prompt)
|
||||||
cur_len = input_ids.shape[-1] - decoder_prompt_len + 1
|
cur_len = input_ids.shape[-1] + 1
|
||||||
batch_size = len(self._beam_hyps) // self.num_beam_groups
|
batch_size = len(self._beam_hyps) // self.num_beam_groups
|
||||||
|
|
||||||
if not (batch_size == (input_ids.shape[0] // self.group_size)):
|
if not (batch_size == (input_ids.shape[0] // self.group_size)):
|
||||||
@@ -279,15 +279,11 @@ class BeamSearchScorer(BeamScorer):
|
|||||||
else:
|
else:
|
||||||
beam_index = None
|
beam_index = None
|
||||||
|
|
||||||
# skip the corner case where the very first generated token is eos_token
|
|
||||||
if decoder_prompt_len == input_ids.shape[-1]:
|
|
||||||
continue
|
|
||||||
|
|
||||||
self._beam_hyps[batch_group_idx].add(
|
self._beam_hyps[batch_group_idx].add(
|
||||||
input_ids[batch_beam_idx].clone(),
|
input_ids[batch_beam_idx].clone(),
|
||||||
next_score.item(),
|
next_score.item(),
|
||||||
beam_indices=beam_index,
|
beam_indices=beam_index,
|
||||||
decoder_prompt_len=decoder_prompt_len,
|
generated_len=cur_len - decoder_prompt_len,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# add next predicted token since it is not eos_token
|
# add next predicted token since it is not eos_token
|
||||||
@@ -308,7 +304,7 @@ class BeamSearchScorer(BeamScorer):
|
|||||||
|
|
||||||
# Check if we are done so that we can save a pad step if all(done)
|
# Check if we are done so that we can save a pad step if all(done)
|
||||||
self._done[batch_group_idx] = self._done[batch_group_idx] or self._beam_hyps[batch_group_idx].is_done(
|
self._done[batch_group_idx] = self._done[batch_group_idx] or self._beam_hyps[batch_group_idx].is_done(
|
||||||
next_scores[batch_idx].max().item(), cur_len
|
next_scores[batch_idx].max().item(), cur_len, decoder_prompt_len
|
||||||
)
|
)
|
||||||
|
|
||||||
return UserDict(
|
return UserDict(
|
||||||
@@ -348,7 +344,8 @@ class BeamSearchScorer(BeamScorer):
|
|||||||
final_score = final_beam_scores[batch_beam_idx].item()
|
final_score = final_beam_scores[batch_beam_idx].item()
|
||||||
final_tokens = input_ids[batch_beam_idx]
|
final_tokens = input_ids[batch_beam_idx]
|
||||||
beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
|
beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
|
||||||
beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, decoder_prompt_len=decoder_prompt_len)
|
generated_len = final_tokens.shape[-1] - decoder_prompt_len
|
||||||
|
beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, generated_len=generated_len)
|
||||||
|
|
||||||
# select the best hypotheses
|
# select the best hypotheses
|
||||||
sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
|
sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
|
||||||
@@ -560,8 +557,8 @@ class ConstrainedBeamSearchScorer(BeamScorer):
|
|||||||
indicating to which beam the next tokens shall be added.
|
indicating to which beam the next tokens shall be added.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# add up to the length which the next_scores is calculated on
|
# add up to the length which the next_scores is calculated on (including decoder prompt)
|
||||||
cur_len = input_ids.shape[-1] - decoder_prompt_len + 1
|
cur_len = input_ids.shape[-1] + 1
|
||||||
batch_size = len(self._beam_hyps)
|
batch_size = len(self._beam_hyps)
|
||||||
if not (batch_size == (input_ids.shape[0] // self.group_size)):
|
if not (batch_size == (input_ids.shape[0] // self.group_size)):
|
||||||
if self.num_beam_groups > 1:
|
if self.num_beam_groups > 1:
|
||||||
@@ -617,16 +614,11 @@ class ConstrainedBeamSearchScorer(BeamScorer):
|
|||||||
else:
|
else:
|
||||||
beam_index = None
|
beam_index = None
|
||||||
|
|
||||||
# skip the corner case where the only constraint token is
|
|
||||||
# eos_token and the very first generated token is eos_token
|
|
||||||
if decoder_prompt_len == input_ids.shape[-1]:
|
|
||||||
continue
|
|
||||||
|
|
||||||
beam_hyp.add(
|
beam_hyp.add(
|
||||||
input_ids[batch_beam_idx].clone(),
|
input_ids[batch_beam_idx].clone(),
|
||||||
next_score.item(),
|
next_score.item(),
|
||||||
beam_indices=beam_index,
|
beam_indices=beam_index,
|
||||||
decoder_prompt_len=decoder_prompt_len,
|
generated_len=cur_len - decoder_prompt_len,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# add next predicted token since it is not eos_token
|
# add next predicted token since it is not eos_token
|
||||||
@@ -660,7 +652,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
|
|||||||
|
|
||||||
# Check if we are done so that we can save a pad step if all(done)
|
# Check if we are done so that we can save a pad step if all(done)
|
||||||
self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done(
|
self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done(
|
||||||
next_scores[batch_idx].max().item(), cur_len
|
next_scores[batch_idx].max().item(), cur_len, decoder_prompt_len
|
||||||
)
|
)
|
||||||
|
|
||||||
return UserDict(
|
return UserDict(
|
||||||
@@ -846,9 +838,8 @@ class ConstrainedBeamSearchScorer(BeamScorer):
|
|||||||
completes_constraint = self.check_completes_constraints(final_tokens.cpu().tolist())
|
completes_constraint = self.check_completes_constraints(final_tokens.cpu().tolist())
|
||||||
if completes_constraint:
|
if completes_constraint:
|
||||||
beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
|
beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
|
||||||
beam_hyp.add(
|
generated_len = final_tokens.shape[-1] - decoder_prompt_len
|
||||||
final_tokens, final_score, beam_indices=beam_index, decoder_prompt_len=decoder_prompt_len
|
beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, generated_len=generated_len)
|
||||||
)
|
|
||||||
ids_collect.append(beam_id)
|
ids_collect.append(beam_id)
|
||||||
|
|
||||||
# due to overly complex constraints or other factors, sometimes we can't gaurantee a successful
|
# due to overly complex constraints or other factors, sometimes we can't gaurantee a successful
|
||||||
@@ -859,7 +850,8 @@ class ConstrainedBeamSearchScorer(BeamScorer):
|
|||||||
batch_beam_idx = batch_idx * self.num_beams + beam_id
|
batch_beam_idx = batch_idx * self.num_beams + beam_id
|
||||||
final_score = final_beam_scores[batch_beam_idx].item()
|
final_score = final_beam_scores[batch_beam_idx].item()
|
||||||
final_tokens = input_ids[batch_beam_idx]
|
final_tokens = input_ids[batch_beam_idx]
|
||||||
beam_hyp.add(final_tokens, final_score, decoder_prompt_len=decoder_prompt_len)
|
generated_len = final_tokens.shape[-1] - decoder_prompt_len
|
||||||
|
beam_hyp.add(final_tokens, final_score, generated_len=generated_len)
|
||||||
if len(ids_collect) >= self.num_beam_hyps_to_keep:
|
if len(ids_collect) >= self.num_beam_hyps_to_keep:
|
||||||
break
|
break
|
||||||
|
|
||||||
@@ -956,12 +948,17 @@ class BeamHypotheses:
|
|||||||
hyp: torch.LongTensor,
|
hyp: torch.LongTensor,
|
||||||
sum_logprobs: float,
|
sum_logprobs: float,
|
||||||
beam_indices: Optional[torch.LongTensor] = None,
|
beam_indices: Optional[torch.LongTensor] = None,
|
||||||
decoder_prompt_len: Optional[int] = 0,
|
generated_len: Optional[int] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Add a new hypothesis to the list.
|
Add a new hypothesis to the list.
|
||||||
"""
|
"""
|
||||||
score = sum_logprobs / ((hyp.shape[-1] - decoder_prompt_len) ** self.length_penalty)
|
if generated_len is not None:
|
||||||
|
score = sum_logprobs / (generated_len**self.length_penalty)
|
||||||
|
# This 'else' case exists for retrocompatibility
|
||||||
|
else:
|
||||||
|
score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)
|
||||||
|
|
||||||
if len(self) < self.num_beams or score > self.worst_score:
|
if len(self) < self.num_beams or score > self.worst_score:
|
||||||
self.beams.append((score, hyp, beam_indices))
|
self.beams.append((score, hyp, beam_indices))
|
||||||
if len(self) > self.num_beams:
|
if len(self) > self.num_beams:
|
||||||
@@ -971,7 +968,7 @@ class BeamHypotheses:
|
|||||||
else:
|
else:
|
||||||
self.worst_score = min(score, self.worst_score)
|
self.worst_score = min(score, self.worst_score)
|
||||||
|
|
||||||
def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool:
|
def is_done(self, best_sum_logprobs: float, cur_len: int, decoder_prompt_len: Optional[int] = 0) -> bool:
|
||||||
"""
|
"""
|
||||||
If there are enough hypotheses and that none of the hypotheses being generated can become better than the worst
|
If there are enough hypotheses and that none of the hypotheses being generated can become better than the worst
|
||||||
one in the heap, then we are done with this sentence.
|
one in the heap, then we are done with this sentence.
|
||||||
@@ -987,7 +984,7 @@ class BeamHypotheses:
|
|||||||
# when `length_penalty` is positive. See the discussion below for more details.
|
# when `length_penalty` is positive. See the discussion below for more details.
|
||||||
# https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
|
# https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
|
||||||
elif self.early_stopping is False:
|
elif self.early_stopping is False:
|
||||||
highest_attainable_score = best_sum_logprobs / cur_len**self.length_penalty
|
highest_attainable_score = best_sum_logprobs / (cur_len - decoder_prompt_len) ** self.length_penalty
|
||||||
ret = self.worst_score >= highest_attainable_score
|
ret = self.worst_score >= highest_attainable_score
|
||||||
return ret
|
return ret
|
||||||
# `"never"`: compute the best possible score, depending on the signal of `length_penalty`
|
# `"never"`: compute the best possible score, depending on the signal of `length_penalty`
|
||||||
@@ -996,9 +993,13 @@ class BeamHypotheses:
|
|||||||
# abs(`highest_attainable_score`) is obtained -> `highest_attainable_score` is negative, hence we obtain
|
# abs(`highest_attainable_score`) is obtained -> `highest_attainable_score` is negative, hence we obtain
|
||||||
# its max this way
|
# its max this way
|
||||||
if self.length_penalty > 0.0:
|
if self.length_penalty > 0.0:
|
||||||
highest_attainable_score = best_sum_logprobs / self.max_length**self.length_penalty
|
if self.max_length <= decoder_prompt_len:
|
||||||
|
raise ValueError("max_length is not larger than decoder prompt length")
|
||||||
|
highest_attainable_score = (
|
||||||
|
best_sum_logprobs / (self.max_length - decoder_prompt_len) ** self.length_penalty
|
||||||
|
)
|
||||||
# the opposite logic applies here (max `highest_attainable_score` from `cur_len`)
|
# the opposite logic applies here (max `highest_attainable_score` from `cur_len`)
|
||||||
else:
|
else:
|
||||||
highest_attainable_score = best_sum_logprobs / cur_len**self.length_penalty
|
highest_attainable_score = best_sum_logprobs / (cur_len - decoder_prompt_len) ** self.length_penalty
|
||||||
ret = self.worst_score >= highest_attainable_score
|
ret = self.worst_score >= highest_attainable_score
|
||||||
return ret
|
return ret
|
||||||
|
|||||||
@@ -633,10 +633,6 @@ class GenerationIntegrationTestsMixin:
|
|||||||
"do_sample": False,
|
"do_sample": False,
|
||||||
"num_beams": 3,
|
"num_beams": 3,
|
||||||
}
|
}
|
||||||
if is_pt:
|
|
||||||
expectation = 20
|
|
||||||
else:
|
|
||||||
# TODO (joao): fix me
|
|
||||||
expectation = 13
|
expectation = 13
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||||
|
|||||||
@@ -800,7 +800,7 @@ class ViT2GPT2ModelIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
preds, scores = generate_step(pixel_values)
|
preds, scores = generate_step(pixel_values)
|
||||||
|
|
||||||
EXPECTED_SCORES = np.array([-0.64145195])
|
EXPECTED_SCORES = np.array([-0.5956343])
|
||||||
max_diff = np.amax(np.abs(scores - EXPECTED_SCORES))
|
max_diff = np.amax(np.abs(scores - EXPECTED_SCORES))
|
||||||
self.assertLessEqual(max_diff, 1e-4)
|
self.assertLessEqual(max_diff, 1e-4)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user