[cleanup] generate_beam_search comments (#5115)
This commit is contained in:
@@ -1219,9 +1219,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||||||
if len(next_sent_beam) == num_beams:
|
if len(next_sent_beam) == num_beams:
|
||||||
break
|
break
|
||||||
|
|
||||||
# Check if were done so that we can save a pad step if all(done)
|
# Check if we are done so that we can save a pad step if all(done)
|
||||||
done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
|
done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
|
||||||
tf.reduce_max(next_scores[batch_idx]).numpy(), cur_len=cur_len
|
tf.reduce_max(next_scores[batch_idx]).numpy(), cur_len
|
||||||
)
|
)
|
||||||
|
|
||||||
# update next beam content
|
# update next beam content
|
||||||
@@ -1509,7 +1509,7 @@ class BeamHypotheses(object):
|
|||||||
else:
|
else:
|
||||||
self.worst_score = min(score, self.worst_score)
|
self.worst_score = min(score, self.worst_score)
|
||||||
|
|
||||||
def is_done(self, best_sum_logprobs, cur_len=None):
|
def is_done(self, best_sum_logprobs, cur_len):
|
||||||
"""
|
"""
|
||||||
If there are enough hypotheses and that none of the hypotheses being generated
|
If there are enough hypotheses and that none of the hypotheses being generated
|
||||||
can become better than the worst one in the heap, then we are done with this sentence.
|
can become better than the worst one in the heap, then we are done with this sentence.
|
||||||
@@ -1520,8 +1520,6 @@ class BeamHypotheses(object):
|
|||||||
elif self.early_stopping:
|
elif self.early_stopping:
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
if cur_len is None:
|
|
||||||
cur_len = self.max_length
|
|
||||||
cur_score = best_sum_logprobs / cur_len ** self.length_penalty
|
cur_score = best_sum_logprobs / cur_len ** self.length_penalty
|
||||||
ret = self.worst_score >= cur_score
|
ret = self.worst_score >= cur_score
|
||||||
return ret
|
return ret
|
||||||
|
|||||||
@@ -1462,7 +1462,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
# for each sentence
|
# for each sentence
|
||||||
for batch_idx in range(batch_size):
|
for batch_idx in range(batch_size):
|
||||||
|
|
||||||
# if we are done with this sentence
|
# if we are done with this sentence, add a pad token
|
||||||
if done[batch_idx]:
|
if done[batch_idx]:
|
||||||
assert (
|
assert (
|
||||||
len(generated_hyps[batch_idx]) >= num_beams
|
len(generated_hyps[batch_idx]) >= num_beams
|
||||||
@@ -1473,7 +1473,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch
|
next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# next sentence beam content
|
# next sentence beam content, this will get added to next_batch_beam
|
||||||
next_sent_beam = []
|
next_sent_beam = []
|
||||||
|
|
||||||
# next tokens for this sentence
|
# next tokens for this sentence
|
||||||
@@ -1485,7 +1485,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
token_id = beam_token_id % vocab_size
|
token_id = beam_token_id % vocab_size
|
||||||
|
|
||||||
effective_beam_id = batch_idx * num_beams + beam_id
|
effective_beam_id = batch_idx * num_beams + beam_id
|
||||||
# add to generated hypotheses if end of sentence or last iteration
|
# add to generated hypotheses if end of sentence
|
||||||
if (eos_token_id is not None) and (token_id.item() == eos_token_id):
|
if (eos_token_id is not None) and (token_id.item() == eos_token_id):
|
||||||
# if beam_token does not belong to top num_beams tokens, it should not be added
|
# if beam_token does not belong to top num_beams tokens, it should not be added
|
||||||
is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
|
is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
|
||||||
@@ -1495,22 +1495,22 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
input_ids[effective_beam_id].clone(), beam_token_score.item(),
|
input_ids[effective_beam_id].clone(), beam_token_score.item(),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# add next predicted token if it is not eos_token
|
# add next predicted token since it is not eos_token
|
||||||
next_sent_beam.append((beam_token_score, token_id, effective_beam_id))
|
next_sent_beam.append((beam_token_score, token_id, effective_beam_id))
|
||||||
|
|
||||||
# the beam for next step is full
|
# once the beam for next step is full, don't add more tokens to it.
|
||||||
if len(next_sent_beam) == num_beams:
|
if len(next_sent_beam) == num_beams:
|
||||||
break
|
break
|
||||||
|
|
||||||
# Check if were done so that we can save a pad step if all(done)
|
# Check if we are done so that we can save a pad step if all(done)
|
||||||
done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
|
done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
|
||||||
next_scores[batch_idx].max().item(), cur_len=cur_len
|
next_scores[batch_idx].max().item(), cur_len
|
||||||
)
|
)
|
||||||
|
|
||||||
# update next beam content
|
# update next beam content
|
||||||
assert len(next_sent_beam) == num_beams, "Beam should always be full"
|
assert len(next_sent_beam) == num_beams, "Beam should always be full"
|
||||||
next_batch_beam.extend(next_sent_beam)
|
next_batch_beam.extend(next_sent_beam)
|
||||||
assert len(next_batch_beam) == num_beams * (batch_idx + 1)
|
assert len(next_batch_beam) == num_beams * (batch_idx + 1), "We should have added num_beams each step"
|
||||||
|
|
||||||
# stop when we are done with each sentence
|
# stop when we are done with each sentence
|
||||||
if all(done):
|
if all(done):
|
||||||
@@ -1537,7 +1537,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
|
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
|
||||||
)
|
)
|
||||||
|
|
||||||
# finalize all open beam hypotheses and end to generated hypotheses
|
# finalize all open beam hypotheses and add to generated hypotheses
|
||||||
for batch_idx in range(batch_size):
|
for batch_idx in range(batch_size):
|
||||||
if done[batch_idx]:
|
if done[batch_idx]:
|
||||||
continue
|
continue
|
||||||
@@ -1576,7 +1576,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
sent_lengths[effective_batch_idx] = len(best_hyp)
|
sent_lengths[effective_batch_idx] = len(best_hyp)
|
||||||
best.append(best_hyp)
|
best.append(best_hyp)
|
||||||
|
|
||||||
# shorter batches are filled with pad_token
|
# shorter batches are padded
|
||||||
if sent_lengths.min().item() != sent_lengths.max().item():
|
if sent_lengths.min().item() != sent_lengths.max().item():
|
||||||
assert pad_token_id is not None, "`Pad_token_id` has to be defined"
|
assert pad_token_id is not None, "`Pad_token_id` has to be defined"
|
||||||
sent_max_len = min(sent_lengths.max().item() + 1, max_length)
|
sent_max_len = min(sent_lengths.max().item() + 1, max_length)
|
||||||
@@ -1731,7 +1731,7 @@ class BeamHypotheses(object):
|
|||||||
else:
|
else:
|
||||||
self.worst_score = min(score, self.worst_score)
|
self.worst_score = min(score, self.worst_score)
|
||||||
|
|
||||||
def is_done(self, best_sum_logprobs, cur_len=None):
|
def is_done(self, best_sum_logprobs, cur_len):
|
||||||
"""
|
"""
|
||||||
If there are enough hypotheses and that none of the hypotheses being generated
|
If there are enough hypotheses and that none of the hypotheses being generated
|
||||||
can become better than the worst one in the heap, then we are done with this sentence.
|
can become better than the worst one in the heap, then we are done with this sentence.
|
||||||
@@ -1742,8 +1742,6 @@ class BeamHypotheses(object):
|
|||||||
elif self.early_stopping:
|
elif self.early_stopping:
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
if cur_len is None:
|
|
||||||
cur_len = self.max_length
|
|
||||||
cur_score = best_sum_logprobs / cur_len ** self.length_penalty
|
cur_score = best_sum_logprobs / cur_len ** self.length_penalty
|
||||||
ret = self.worst_score >= cur_score
|
ret = self.worst_score >= cur_score
|
||||||
return ret
|
return ret
|
||||||
|
|||||||
Reference in New Issue
Block a user