From 7a89a3e4935cdd7b46765c5737665b10bfed1e28 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 4 Mar 2020 12:02:57 +0100 Subject: [PATCH] correct beam search sampling --- src/transformers/modeling_tf_utils.py | 36 ++++++++++++++++----------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 1dfeecdd8e..68151d93c5 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -760,9 +760,14 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): ] # scores for each sentence in the beam - beam_scores_begin = tf.zeros((batch_size, 1), dtype=tf.float32) - beam_scores_end = tf.zeros((batch_size, num_beams - 1), dtype=tf.float32) * 1e-9 - beam_scores = tf.reshape(tf.concat([beam_scores_begin, beam_scores_end], -1), (batch_size * num_beams,)) + if do_sample is False: + beam_scores_begin = tf.zeros((batch_size, 1), dtype=tf.float32) + beam_scores_end = tf.zeros((batch_size, num_beams - 1), dtype=tf.float32) * 1e-9 + beam_scores = tf.concat([beam_scores_begin, beam_scores_end], -1) + else: + beam_scores = tf.zeros((batch_size, num_beams), dtype=tf.float32) + + beam_scores = tf.reshape(beam_scores, (batch_size * num_beams,)) # cache compute states past = None @@ -790,23 +795,24 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): # Temperature (higher temperature => more likely to sample low probability tokens) if temperature != 1.0: next_token_logits = next_token_logits / temperature + + scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size) + _scores = scores + tf.broadcast_to( + beam_scores[:, None], (batch_size * num_beams, vocab_size) + ) # (batch_size * num_beams, vocab_size) + # Top-p/top-k filtering - next_token_logits = tf_top_k_top_p_filtering( - next_token_logits, top_k=top_k, top_p=top_p, min_tokens_to_keep=2 + _scores = tf_top_k_top_p_filtering( + _scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2 ) # (batch_size * num_beams, vocab_size) # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search) + _scores = tf.reshape(_scores, (batch_size, num_beams * vocab_size)) + next_tokens = tf.random.categorical( - next_token_logits, dtype=tf.int32, num_samples=2 - ) # (batch_size * num_beams, vocab_size) + _scores, dtype=tf.int32, num_samples=2 * num_beams + ) # (batch_size, 2 * num_beams) # Compute next scores - scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size) - _scores = tf.gather(scores, next_tokens, batch_dims=1) # (batch_size * num_beams, 2) - next_scores = _scores + tf.broadcast_to( - beam_scores[:, None], (batch_size * num_beams, 2) - ) # (batch_size * num_beams, 2) - # Match shape of greedy beam search - next_tokens = tf.reshape(next_tokens, (batch_size, 2 * num_beams)) # (batch_size, 2 * num_beams) - next_scores = tf.reshape(next_scores, (batch_size, 2 * num_beams)) # (batch_size, 2 * num_beams) + next_scores = tf.gather(_scores, next_tokens, batch_dims=1) # (batch_size, 2 * num_beams) else: # do greedy beam search scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size)