[cleanup] generate_beam_search comments (#5115)

This commit is contained in:
Sam Shleifer
2020-06-18 16:30:24 -04:00
committed by GitHub
parent ca2d0f98c4
commit 3d3e605aff
2 changed files with 14 additions and 18 deletions

View File

@@ -1219,9 +1219,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
if len(next_sent_beam) == num_beams: if len(next_sent_beam) == num_beams:
break break
# Check if were done so that we can save a pad step if all(done) # Check if we are done so that we can save a pad step if all(done)
done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done( done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
tf.reduce_max(next_scores[batch_idx]).numpy(), cur_len=cur_len tf.reduce_max(next_scores[batch_idx]).numpy(), cur_len
) )
# update next beam content # update next beam content
@@ -1509,7 +1509,7 @@ class BeamHypotheses(object):
else: else:
self.worst_score = min(score, self.worst_score) self.worst_score = min(score, self.worst_score)
def is_done(self, best_sum_logprobs, cur_len=None): def is_done(self, best_sum_logprobs, cur_len):
""" """
If there are enough hypotheses and that none of the hypotheses being generated If there are enough hypotheses and that none of the hypotheses being generated
can become better than the worst one in the heap, then we are done with this sentence. can become better than the worst one in the heap, then we are done with this sentence.
@@ -1520,8 +1520,6 @@ class BeamHypotheses(object):
elif self.early_stopping: elif self.early_stopping:
return True return True
else: else:
if cur_len is None:
cur_len = self.max_length
cur_score = best_sum_logprobs / cur_len ** self.length_penalty cur_score = best_sum_logprobs / cur_len ** self.length_penalty
ret = self.worst_score >= cur_score ret = self.worst_score >= cur_score
return ret return ret

View File

@@ -1462,7 +1462,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# for each sentence # for each sentence
for batch_idx in range(batch_size): for batch_idx in range(batch_size):
# if we are done with this sentence # if we are done with this sentence, add a pad token
if done[batch_idx]: if done[batch_idx]:
assert ( assert (
len(generated_hyps[batch_idx]) >= num_beams len(generated_hyps[batch_idx]) >= num_beams
@@ -1473,7 +1473,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch
continue continue
# next sentence beam content # next sentence beam content, this will get added to next_batch_beam
next_sent_beam = [] next_sent_beam = []
# next tokens for this sentence # next tokens for this sentence
@@ -1485,7 +1485,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
token_id = beam_token_id % vocab_size token_id = beam_token_id % vocab_size
effective_beam_id = batch_idx * num_beams + beam_id effective_beam_id = batch_idx * num_beams + beam_id
# add to generated hypotheses if end of sentence or last iteration # add to generated hypotheses if end of sentence
if (eos_token_id is not None) and (token_id.item() == eos_token_id): if (eos_token_id is not None) and (token_id.item() == eos_token_id):
# if beam_token does not belong to top num_beams tokens, it should not be added # if beam_token does not belong to top num_beams tokens, it should not be added
is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
@@ -1495,22 +1495,22 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
input_ids[effective_beam_id].clone(), beam_token_score.item(), input_ids[effective_beam_id].clone(), beam_token_score.item(),
) )
else: else:
# add next predicted token if it is not eos_token # add next predicted token since it is not eos_token
next_sent_beam.append((beam_token_score, token_id, effective_beam_id)) next_sent_beam.append((beam_token_score, token_id, effective_beam_id))
# the beam for next step is full # once the beam for next step is full, don't add more tokens to it.
if len(next_sent_beam) == num_beams: if len(next_sent_beam) == num_beams:
break break
# Check if were done so that we can save a pad step if all(done) # Check if we are done so that we can save a pad step if all(done)
done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done( done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
next_scores[batch_idx].max().item(), cur_len=cur_len next_scores[batch_idx].max().item(), cur_len
) )
# update next beam content # update next beam content
assert len(next_sent_beam) == num_beams, "Beam should always be full" assert len(next_sent_beam) == num_beams, "Beam should always be full"
next_batch_beam.extend(next_sent_beam) next_batch_beam.extend(next_sent_beam)
assert len(next_batch_beam) == num_beams * (batch_idx + 1) assert len(next_batch_beam) == num_beams * (batch_idx + 1), "We should have added num_beams each step"
# stop when we are done with each sentence # stop when we are done with each sentence
if all(done): if all(done):
@@ -1537,7 +1537,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
) )
# finalize all open beam hypotheses and end to generated hypotheses # finalize all open beam hypotheses and add to generated hypotheses
for batch_idx in range(batch_size): for batch_idx in range(batch_size):
if done[batch_idx]: if done[batch_idx]:
continue continue
@@ -1576,7 +1576,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
sent_lengths[effective_batch_idx] = len(best_hyp) sent_lengths[effective_batch_idx] = len(best_hyp)
best.append(best_hyp) best.append(best_hyp)
# shorter batches are filled with pad_token # shorter batches are padded
if sent_lengths.min().item() != sent_lengths.max().item(): if sent_lengths.min().item() != sent_lengths.max().item():
assert pad_token_id is not None, "`Pad_token_id` has to be defined" assert pad_token_id is not None, "`Pad_token_id` has to be defined"
sent_max_len = min(sent_lengths.max().item() + 1, max_length) sent_max_len = min(sent_lengths.max().item() + 1, max_length)
@@ -1731,7 +1731,7 @@ class BeamHypotheses(object):
else: else:
self.worst_score = min(score, self.worst_score) self.worst_score = min(score, self.worst_score)
def is_done(self, best_sum_logprobs, cur_len=None): def is_done(self, best_sum_logprobs, cur_len):
""" """
If there are enough hypotheses and that none of the hypotheses being generated If there are enough hypotheses and that none of the hypotheses being generated
can become better than the worst one in the heap, then we are done with this sentence. can become better than the worst one in the heap, then we are done with this sentence.
@@ -1742,8 +1742,6 @@ class BeamHypotheses(object):
elif self.early_stopping: elif self.early_stopping:
return True return True
else: else:
if cur_len is None:
cur_len = self.max_length
cur_score = best_sum_logprobs / cur_len ** self.length_penalty cur_score = best_sum_logprobs / cur_len ** self.length_penalty
ret = self.worst_score >= cur_score ret = self.worst_score >= cur_score
return ret return ret