delete print and make style
This commit is contained in:
@@ -926,7 +926,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||||||
if temperature != 1.0:
|
if temperature != 1.0:
|
||||||
next_token_logits = next_token_logits / temperature
|
next_token_logits = next_token_logits / temperature
|
||||||
|
|
||||||
# calculate log softmax score
|
# calculate log softmax score
|
||||||
scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size)
|
scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size)
|
||||||
|
|
||||||
# set eos token prob to zero if min_length is not reached
|
# set eos token prob to zero if min_length is not reached
|
||||||
@@ -937,9 +937,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||||||
)
|
)
|
||||||
eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size])
|
eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size])
|
||||||
|
|
||||||
scores = set_tensor_by_indices_to_value(
|
scores = set_tensor_by_indices_to_value(scores, eos_token_indices_mask, -float("inf"))
|
||||||
scores, eos_token_indices_mask, -float("inf")
|
|
||||||
)
|
|
||||||
|
|
||||||
if no_repeat_ngram_size > 0:
|
if no_repeat_ngram_size > 0:
|
||||||
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
|
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
|
||||||
@@ -992,7 +990,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||||||
) # (batch_size, num_beams * vocab_size)
|
) # (batch_size, num_beams * vocab_size)
|
||||||
|
|
||||||
next_scores, next_tokens = tf.math.top_k(next_scores, k=2 * num_beams, sorted=True)
|
next_scores, next_tokens = tf.math.top_k(next_scores, k=2 * num_beams, sorted=True)
|
||||||
print(next_tokens)
|
|
||||||
|
|
||||||
assert shape_list(next_scores) == shape_list(next_tokens) == [batch_size, 2 * num_beams]
|
assert shape_list(next_scores) == shape_list(next_tokens) == [batch_size, 2 * num_beams]
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user