rename variable
This commit is contained in:
@@ -990,7 +990,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||||||
next_scores, (batch_size, num_beams * vocab_size)
|
next_scores, (batch_size, num_beams * vocab_size)
|
||||||
) # (batch_size, num_beams * vocab_size)
|
) # (batch_size, num_beams * vocab_size)
|
||||||
|
|
||||||
next_scores, next_tokens = tf.math.top_k(next_scores, 2 * num_beams, sorted=True)
|
next_scores, next_tokens = tf.math.top_k(next_scores, k=2 * num_beams, sorted=True)
|
||||||
|
|
||||||
assert shape_list(next_scores) == shape_list(next_tokens) == [batch_size, 2 * num_beams]
|
assert shape_list(next_scores) == shape_list(next_tokens) == [batch_size, 2 * num_beams]
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user