From ca1330f0b2b98583da694eb39d6a3f90ab4261c7 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 10 Mar 2020 14:22:16 +0100 Subject: [PATCH] do not mess with the negative sign --- src/transformers/modeling_tf_utils.py | 37 ++++++++++++++------------- 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index cfaa10b046..e714df4731 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -894,7 +894,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): # for greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens three times if do_sample is False: beam_scores_begin = tf.zeros((batch_size, 1), dtype=tf.float32) - beam_scores_end = tf.zeros((batch_size, num_beams - 1), dtype=tf.float32) * 1e-9 + beam_scores_end = tf.ones((batch_size, num_beams - 1), dtype=tf.float32) * (-1e9) beam_scores = tf.concat([beam_scores_begin, beam_scores_end], -1) else: beam_scores = tf.zeros((batch_size, num_beams), dtype=tf.float32) @@ -926,6 +926,21 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): if temperature != 1.0: next_token_logits = next_token_logits / temperature +# calculate log softmax score + scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size) + + # set eos token prob to zero if min_length is not reached + if eos_token_ids is not None and cur_len < min_length: + # create eos_token_ids boolean mask + is_token_logit_eos_token = tf.convert_to_tensor( + [True if token in eos_token_ids else False for token in range(vocab_size)], dtype=tf.bool + ) + eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size]) + + scores = set_tensor_by_indices_to_value( + scores, eos_token_indices_mask, -float("inf") + ) + if no_repeat_ngram_size > 0: # calculate a list of banned tokens to prevent repetitively generating the same ngrams # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345 @@ -937,24 +952,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): [True if token in banned_tokens_slice else False for token in range(vocab_size)] ) - next_token_logits = set_tensor_by_indices_to_value( - next_token_logits, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf") + scores = set_tensor_by_indices_to_value( + scores, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf") ) - # set eos token prob to zero if min_length is not reached - if eos_token_ids is not None and cur_len < min_length: - # create eos_token_ids boolean mask - is_token_logit_eos_token = tf.convert_to_tensor( - [True if token in eos_token_ids else False for token in range(vocab_size)], dtype=tf.bool - ) - eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size]) - - next_token_logits = set_tensor_by_indices_to_value( - next_token_logits, eos_token_indices_mask, -float("inf") - ) - - # calculate log softmax score - scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size) assert shape_list(scores) == [batch_size * num_beams, vocab_size] if do_sample: @@ -991,6 +992,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): ) # (batch_size, num_beams * vocab_size) next_scores, next_tokens = tf.math.top_k(next_scores, k=2 * num_beams, sorted=True) + print(next_tokens) assert shape_list(next_scores) == shape_list(next_tokens) == [batch_size, 2 * num_beams] @@ -1064,7 +1066,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): # re-order batch input_ids = tf.stack([tf.identity(input_ids[x, :]) for x in beam_idx]) input_ids = tf.concat([input_ids, tf.expand_dims(beam_tokens, 1)], axis=-1) - # re-order internal states if past: past = self._reorder_cache(past, beam_idx)