correct beam search sampling
This commit is contained in:
@@ -760,9 +760,14 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||||||
]
|
]
|
||||||
|
|
||||||
# scores for each sentence in the beam
|
# scores for each sentence in the beam
|
||||||
beam_scores_begin = tf.zeros((batch_size, 1), dtype=tf.float32)
|
if do_sample is False:
|
||||||
beam_scores_end = tf.zeros((batch_size, num_beams - 1), dtype=tf.float32) * 1e-9
|
beam_scores_begin = tf.zeros((batch_size, 1), dtype=tf.float32)
|
||||||
beam_scores = tf.reshape(tf.concat([beam_scores_begin, beam_scores_end], -1), (batch_size * num_beams,))
|
beam_scores_end = tf.zeros((batch_size, num_beams - 1), dtype=tf.float32) * 1e-9
|
||||||
|
beam_scores = tf.concat([beam_scores_begin, beam_scores_end], -1)
|
||||||
|
else:
|
||||||
|
beam_scores = tf.zeros((batch_size, num_beams), dtype=tf.float32)
|
||||||
|
|
||||||
|
beam_scores = tf.reshape(beam_scores, (batch_size * num_beams,))
|
||||||
|
|
||||||
# cache compute states
|
# cache compute states
|
||||||
past = None
|
past = None
|
||||||
@@ -790,23 +795,24 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||||||
# Temperature (higher temperature => more likely to sample low probability tokens)
|
# Temperature (higher temperature => more likely to sample low probability tokens)
|
||||||
if temperature != 1.0:
|
if temperature != 1.0:
|
||||||
next_token_logits = next_token_logits / temperature
|
next_token_logits = next_token_logits / temperature
|
||||||
|
|
||||||
|
scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size)
|
||||||
|
_scores = scores + tf.broadcast_to(
|
||||||
|
beam_scores[:, None], (batch_size * num_beams, vocab_size)
|
||||||
|
) # (batch_size * num_beams, vocab_size)
|
||||||
|
|
||||||
# Top-p/top-k filtering
|
# Top-p/top-k filtering
|
||||||
next_token_logits = tf_top_k_top_p_filtering(
|
_scores = tf_top_k_top_p_filtering(
|
||||||
next_token_logits, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
|
_scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
|
||||||
) # (batch_size * num_beams, vocab_size)
|
) # (batch_size * num_beams, vocab_size)
|
||||||
# Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
|
# Sample 2 * num_beams next tokens per batch item, drawn jointly over all beams (so we have some spare tokens and match output of greedy beam search)
|
||||||
|
_scores = tf.reshape(_scores, (batch_size, num_beams * vocab_size))
|
||||||
|
|
||||||
next_tokens = tf.random.categorical(
|
next_tokens = tf.random.categorical(
|
||||||
next_token_logits, dtype=tf.int32, num_samples=2
|
_scores, dtype=tf.int32, num_samples=2 * num_beams
|
||||||
) # (batch_size * num_beams, 2)
|
) # (batch_size, 2 * num_beams)
|
||||||
# Compute next scores
|
# Compute next scores
|
||||||
scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size)
|
next_scores = tf.gather(_scores, next_tokens, batch_dims=1) # (batch_size, 2 * num_beams)
|
||||||
_scores = tf.gather(scores, next_tokens, batch_dims=1) # (batch_size * num_beams, 2)
|
|
||||||
next_scores = _scores + tf.broadcast_to(
|
|
||||||
beam_scores[:, None], (batch_size * num_beams, 2)
|
|
||||||
) # (batch_size * num_beams, 2)
|
|
||||||
# Match shape of greedy beam search
|
|
||||||
next_tokens = tf.reshape(next_tokens, (batch_size, 2 * num_beams)) # (batch_size, 2 * num_beams)
|
|
||||||
next_scores = tf.reshape(next_scores, (batch_size, 2 * num_beams)) # (batch_size, 2 * num_beams)
|
|
||||||
else:
|
else:
|
||||||
# do greedy beam search
|
# do greedy beam search
|
||||||
scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size)
|
scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size)
|
||||||
|
|||||||
Reference in New Issue
Block a user