* fix typo

Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
Muennighoff
2021-01-21 09:21:01 +01:00
committed by GitHub
parent 4a20b7c450
commit 6a346f0358

View File

@@ -147,7 +147,7 @@ class TFCausalLanguageModelingLoss:
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True, reduction=tf.keras.losses.Reduction.NONE
)
# make sure only labels that are not equal to -100 do not affect loss
# make sure only labels that are not equal to -100 affect the loss
active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100)
reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)