From 3d3e605affb792b78c918aac48f6bc82cfbf7e3e Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Thu, 18 Jun 2020 16:30:24 -0400 Subject: [PATCH] [cleanup] generate_beam_search comments (#5115) --- src/transformers/modeling_tf_utils.py | 8 +++----- src/transformers/modeling_utils.py | 24 +++++++++++------------- 2 files changed, 14 insertions(+), 18 deletions(-) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 54599e5465..4d31c00bd7 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -1219,9 +1219,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): if len(next_sent_beam) == num_beams: break - # Check if were done so that we can save a pad step if all(done) + # Check if we are done so that we can save a pad step if all(done) done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done( - tf.reduce_max(next_scores[batch_idx]).numpy(), cur_len=cur_len + tf.reduce_max(next_scores[batch_idx]).numpy(), cur_len ) # update next beam content @@ -1509,7 +1509,7 @@ class BeamHypotheses(object): else: self.worst_score = min(score, self.worst_score) - def is_done(self, best_sum_logprobs, cur_len=None): + def is_done(self, best_sum_logprobs, cur_len): """ If there are enough hypotheses and that none of the hypotheses being generated can become better than the worst one in the heap, then we are done with this sentence. @@ -1520,8 +1520,6 @@ class BeamHypotheses(object): elif self.early_stopping: return True else: - if cur_len is None: - cur_len = self.max_length cur_score = best_sum_logprobs / cur_len ** self.length_penalty ret = self.worst_score >= cur_score return ret diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index c6b3ff7d24..da072491e9 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1462,7 +1462,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): # for each sentence for batch_idx in range(batch_size): - # if we are done with this sentence + # if we are done with this sentence, add a pad token if done[batch_idx]: assert ( len(generated_hyps[batch_idx]) >= num_beams @@ -1473,7 +1473,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch continue - # next sentence beam content + # next sentence beam content, this will get added to next_batch_beam next_sent_beam = [] # next tokens for this sentence @@ -1485,7 +1485,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): token_id = beam_token_id % vocab_size effective_beam_id = batch_idx * num_beams + beam_id - # add to generated hypotheses if end of sentence or last iteration + # add to generated hypotheses if end of sentence if (eos_token_id is not None) and (token_id.item() == eos_token_id): # if beam_token does not belong to top num_beams tokens, it should not be added is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams @@ -1495,22 +1495,22 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): input_ids[effective_beam_id].clone(), beam_token_score.item(), ) else: - # add next predicted token if it is not eos_token + # add next predicted token since it is not eos_token next_sent_beam.append((beam_token_score, token_id, effective_beam_id)) - # the beam for next step is full + # once the beam for next step is full, don't add more tokens to it. if len(next_sent_beam) == num_beams: break - # Check if were done so that we can save a pad step if all(done) + # Check if we are done so that we can save a pad step if all(done) done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done( - next_scores[batch_idx].max().item(), cur_len=cur_len + next_scores[batch_idx].max().item(), cur_len ) # update next beam content assert len(next_sent_beam) == num_beams, "Beam should always be full" next_batch_beam.extend(next_sent_beam) - assert len(next_batch_beam) == num_beams * (batch_idx + 1) + assert len(next_batch_beam) == num_beams * (batch_idx + 1), "We should have added num_beams each step" # stop when we are done with each sentence if all(done): @@ -1537,7 +1537,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 ) - # finalize all open beam hypotheses and end to generated hypotheses + # finalize all open beam hypotheses and add to generated hypotheses for batch_idx in range(batch_size): if done[batch_idx]: continue @@ -1576,7 +1576,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): sent_lengths[effective_batch_idx] = len(best_hyp) best.append(best_hyp) - # shorter batches are filled with pad_token + # shorter batches are padded if sent_lengths.min().item() != sent_lengths.max().item(): assert pad_token_id is not None, "`Pad_token_id` has to be defined" sent_max_len = min(sent_lengths.max().item() + 1, max_length) @@ -1731,7 +1731,7 @@ class BeamHypotheses(object): else: self.worst_score = min(score, self.worst_score) - def is_done(self, best_sum_logprobs, cur_len=None): + def is_done(self, best_sum_logprobs, cur_len): """ If there are enough hypotheses and that none of the hypotheses being generated can become better than the worst one in the heap, then we are done with this sentence. @@ -1742,8 +1742,6 @@ class BeamHypotheses(object): elif self.early_stopping: return True else: - if cur_len is None: - cur_len = self.max_length cur_score = best_sum_logprobs / cur_len ** self.length_penalty ret = self.worst_score >= cur_score return ret