fix typo (#9708)
* fix typo Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
@@ -147,7 +147,7 @@ class TFCausalLanguageModelingLoss:
|
||||
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
|
||||
from_logits=True, reduction=tf.keras.losses.Reduction.NONE
|
||||
)
|
||||
# make sure only labels that are not equal to -100 do not affect loss
|
||||
# make sure only labels that are not equal to -100 affect the loss
|
||||
active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100)
|
||||
reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
|
||||
labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
|
||||
|
||||
Reference in New Issue
Block a user