From 06f1692b023a701ab2bb443fa4f0bdd58c6bd234 Mon Sep 17 00:00:00 2001 From: Maurice Gonzenbach Date: Mon, 3 Aug 2020 16:21:23 +0200 Subject: [PATCH] Fix _shift_right function in TFT5PreTrainedModel (#6214) --- src/transformers/modeling_tf_t5.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/transformers/modeling_tf_t5.py b/src/transformers/modeling_tf_t5.py index 9b885d496f..9858b8ae76 100644 --- a/src/transformers/modeling_tf_t5.py +++ b/src/transformers/modeling_tf_t5.py @@ -783,8 +783,7 @@ class TFT5PreTrainedModel(TFPreTrainedModel): decoder_start_token_id is not None ), "self.model.config.decoder_start_token_id has to be defined. In TF T5 it is usually set to the pad_token_id. See T5 docs for more information" - # shift inputs to the right - shifted_input_ids = tf.zeros_like(input_ids, dtype=tf.int32) + shifted_input_ids = tf.cast(input_ids, tf.int32) shifted_input_ids = tf.roll(shifted_input_ids, 1, axis=-1) start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id) shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1) @@ -795,9 +794,12 @@ class TFT5PreTrainedModel(TFPreTrainedModel): shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids ) - assert tf.math.reduce_any( - shifted_input_ids >= 0 - ).numpy(), "Verify that `labels` has only positive values and -100" + # "Verify that `labels` has only positive values and -100" + assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.cast(0, tf.int32)) + + # Make sure the assertion op is called by wrapping the result in an identity no-op + with tf.control_dependencies([assert_gte0]): + shifted_input_ids = tf.identity(shifted_input_ids) return shifted_input_ids