From 6a346f0358a40f89ec384d441233bf54cac44f6a Mon Sep 17 00:00:00 2001 From: Muennighoff <62820084+Muennighoff@users.noreply.github.com> Date: Thu, 21 Jan 2021 09:21:01 +0100 Subject: [PATCH] fix typo (#9708) * fix typo Co-authored-by: Suraj Patil --- src/transformers/modeling_tf_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 6c8b698e87..0f390c9a91 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -147,7 +147,7 @@ class TFCausalLanguageModelingLoss: loss_fn = tf.keras.losses.SparseCategoricalCrossentropy( from_logits=True, reduction=tf.keras.losses.Reduction.NONE ) - # make sure only labels that are not equal to -100 do not affect loss + # make sure only labels that are not equal to -100 affect the loss active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100) reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss) labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)