TFBart lables consider both pad token and -100 (#9847)

* TFBart lables consider both pad token and -100

* make style

* fix for all other models

Co-authored-by: kykim <kykim>
Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
This commit is contained in:
Kiyoung Kim
2021-02-01 07:31:29 +09:00
committed by GitHub
parent 22121e813e
commit 74f16b8276
6 changed files with 44 additions and 83 deletions

View File

@@ -38,6 +38,7 @@ from ...modeling_tf_outputs import (
# Public API
from ...modeling_tf_utils import (
DUMMY_INPUTS,
TFCausalLanguageModelingLoss,
TFPreTrainedModel,
TFSharedEmbeddings,
TFWrappedEmbeddings,
@@ -1239,7 +1240,7 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel):
"The BLENDERBOT_SMALL Model with a language modeling head. Can be used for summarization.",
BLENDERBOT_SMALL_START_DOCSTRING,
)
class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel):
class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel, TFCausalLanguageModelingLoss):
_keys_to_ignore_on_load_unexpected = [
r"model.encoder.embed_tokens.weight",
r"model.decoder.embed_tokens.weight",
@@ -1327,6 +1328,12 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
)
if inputs["labels"] is not None:
inputs["labels"] = tf.where(
inputs["labels"] == self.config.pad_token_id,
tf.fill(shape_list(inputs["labels"]), -100),
inputs["labels"],
)
inputs["use_cache"] = False
if inputs["decoder_input_ids"] is None:
inputs["decoder_input_ids"] = shift_tokens_right(
inputs["labels"], self.config.pad_token_id, self.config.decoder_start_token_id
@@ -1452,16 +1459,3 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
return tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits)
else:
return logits
# Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.compute_loss
def compute_loss(self, labels, logits):
"""CrossEntropyLoss that ignores pad tokens"""
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True,
reduction=tf.keras.losses.Reduction.NONE,
)
melted_labels = tf.reshape(labels, (-1,))
active_loss = tf.not_equal(melted_labels, self.config.pad_token_id)
reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
labels = tf.boolean_mask(melted_labels, active_loss)
return loss_fn(labels, reduced_logits)