rename variable
This commit is contained in:
@@ -990,7 +990,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
next_scores, (batch_size, num_beams * vocab_size)
|
||||
) # (batch_size, num_beams * vocab_size)
|
||||
|
||||
next_scores, next_tokens = tf.math.top_k(next_scores, 2 * num_beams, sorted=True)
|
||||
next_scores, next_tokens = tf.math.top_k(next_scores, k=2 * num_beams, sorted=True)
|
||||
|
||||
assert shape_list(next_scores) == shape_list(next_tokens) == [batch_size, 2 * num_beams]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user