From 9b8ee8cea0c5c66c0850b880aa8e659e01211dd9 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 10 Mar 2020 14:32:21 +0100 Subject: [PATCH] delete print and make style --- src/transformers/modeling_tf_utils.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index e714df4731..8cab7619a0 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -926,7 +926,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): if temperature != 1.0: next_token_logits = next_token_logits / temperature -# calculate log softmax score + # calculate log softmax score scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size) # set eos token prob to zero if min_length is not reached @@ -937,9 +937,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): ) eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size]) - scores = set_tensor_by_indices_to_value( - scores, eos_token_indices_mask, -float("inf") - ) + scores = set_tensor_by_indices_to_value(scores, eos_token_indices_mask, -float("inf")) if no_repeat_ngram_size > 0: # calculate a list of banned tokens to prevent repetitively generating the same ngrams @@ -992,7 +990,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): ) # (batch_size, num_beams * vocab_size) next_scores, next_tokens = tf.math.top_k(next_scores, k=2 * num_beams, sorted=True) - print(next_tokens) assert shape_list(next_scores) == shape_list(next_tokens) == [batch_size, 2 * num_beams]