fix typo (#9708)
* fix typo Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
@@ -147,7 +147,7 @@ class TFCausalLanguageModelingLoss:
|
|||||||
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
|
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
|
||||||
from_logits=True, reduction=tf.keras.losses.Reduction.NONE
|
from_logits=True, reduction=tf.keras.losses.Reduction.NONE
|
||||||
)
|
)
|
||||||
# make sure only labels that are not equal to -100 do not affect loss
|
# make sure only labels that are not equal to -100 affect the loss
|
||||||
active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100)
|
active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100)
|
||||||
reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
|
reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
|
||||||
labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
|
labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
|
||||||
|
|||||||
Reference in New Issue
Block a user