TFBart lables consider both pad token and -100 (#9847)
* TFBart lables consider both pad token and -100 * make style * fix for all other models Co-authored-by: kykim <kykim> Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -38,6 +38,7 @@ from ...modeling_tf_outputs import (
|
|||||||
# Public API
|
# Public API
|
||||||
from ...modeling_tf_utils import (
|
from ...modeling_tf_utils import (
|
||||||
DUMMY_INPUTS,
|
DUMMY_INPUTS,
|
||||||
|
TFCausalLanguageModelingLoss,
|
||||||
TFPreTrainedModel,
|
TFPreTrainedModel,
|
||||||
TFSharedEmbeddings,
|
TFSharedEmbeddings,
|
||||||
TFWrappedEmbeddings,
|
TFWrappedEmbeddings,
|
||||||
@@ -1234,7 +1235,7 @@ class TFBartModel(TFBartPretrainedModel):
|
|||||||
"The BART Model with a language modeling head. Can be used for summarization.",
|
"The BART Model with a language modeling head. Can be used for summarization.",
|
||||||
BART_START_DOCSTRING,
|
BART_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
class TFBartForConditionalGeneration(TFBartPretrainedModel):
|
class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageModelingLoss):
|
||||||
_keys_to_ignore_on_load_unexpected = [
|
_keys_to_ignore_on_load_unexpected = [
|
||||||
r"model.encoder.embed_tokens.weight",
|
r"model.encoder.embed_tokens.weight",
|
||||||
r"model.decoder.embed_tokens.weight",
|
r"model.decoder.embed_tokens.weight",
|
||||||
@@ -1322,6 +1323,11 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if inputs["labels"] is not None:
|
if inputs["labels"] is not None:
|
||||||
|
inputs["labels"] = tf.where(
|
||||||
|
inputs["labels"] == self.config.pad_token_id,
|
||||||
|
tf.fill(shape_list(inputs["labels"]), -100),
|
||||||
|
inputs["labels"],
|
||||||
|
)
|
||||||
inputs["use_cache"] = False
|
inputs["use_cache"] = False
|
||||||
if inputs["decoder_input_ids"] is None:
|
if inputs["decoder_input_ids"] is None:
|
||||||
inputs["decoder_input_ids"] = shift_tokens_right(
|
inputs["decoder_input_ids"] = shift_tokens_right(
|
||||||
@@ -1448,15 +1454,3 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel):
|
|||||||
return tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits)
|
return tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits)
|
||||||
else:
|
else:
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
def compute_loss(self, labels, logits):
|
|
||||||
"""CrossEntropyLoss that ignores pad tokens"""
|
|
||||||
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
|
|
||||||
from_logits=True,
|
|
||||||
reduction=tf.keras.losses.Reduction.NONE,
|
|
||||||
)
|
|
||||||
melted_labels = tf.reshape(labels, (-1,))
|
|
||||||
active_loss = tf.not_equal(melted_labels, self.config.pad_token_id)
|
|
||||||
reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
|
|
||||||
labels = tf.boolean_mask(melted_labels, active_loss)
|
|
||||||
return loss_fn(labels, reduced_logits)
|
|
||||||
|
|||||||
@@ -40,6 +40,7 @@ from ...modeling_tf_outputs import (
|
|||||||
# Public API
|
# Public API
|
||||||
from ...modeling_tf_utils import (
|
from ...modeling_tf_utils import (
|
||||||
DUMMY_INPUTS,
|
DUMMY_INPUTS,
|
||||||
|
TFCausalLanguageModelingLoss,
|
||||||
TFPreTrainedModel,
|
TFPreTrainedModel,
|
||||||
TFSharedEmbeddings,
|
TFSharedEmbeddings,
|
||||||
TFWrappedEmbeddings,
|
TFWrappedEmbeddings,
|
||||||
@@ -1251,7 +1252,7 @@ class TFBlenderbotModel(TFBlenderbotPreTrainedModel):
|
|||||||
"The BLENDERBOT Model with a language modeling head. Can be used for summarization.",
|
"The BLENDERBOT Model with a language modeling head. Can be used for summarization.",
|
||||||
BLENDERBOT_START_DOCSTRING,
|
BLENDERBOT_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel):
|
class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||||
_keys_to_ignore_on_load_unexpected = [
|
_keys_to_ignore_on_load_unexpected = [
|
||||||
r"model.encoder.embed_tokens.weight",
|
r"model.encoder.embed_tokens.weight",
|
||||||
r"model.decoder.embed_tokens.weight",
|
r"model.decoder.embed_tokens.weight",
|
||||||
@@ -1352,6 +1353,12 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if inputs["labels"] is not None:
|
if inputs["labels"] is not None:
|
||||||
|
inputs["labels"] = tf.where(
|
||||||
|
inputs["labels"] == self.config.pad_token_id,
|
||||||
|
tf.fill(shape_list(inputs["labels"]), -100),
|
||||||
|
inputs["labels"],
|
||||||
|
)
|
||||||
|
inputs["use_cache"] = False
|
||||||
if inputs["decoder_input_ids"] is None:
|
if inputs["decoder_input_ids"] is None:
|
||||||
inputs["decoder_input_ids"] = shift_tokens_right(
|
inputs["decoder_input_ids"] = shift_tokens_right(
|
||||||
inputs["labels"], self.config.pad_token_id, self.config.decoder_start_token_id
|
inputs["labels"], self.config.pad_token_id, self.config.decoder_start_token_id
|
||||||
@@ -1477,16 +1484,3 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel):
|
|||||||
return tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits)
|
return tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits)
|
||||||
else:
|
else:
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
# Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.compute_loss
|
|
||||||
def compute_loss(self, labels, logits):
|
|
||||||
"""CrossEntropyLoss that ignores pad tokens"""
|
|
||||||
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
|
|
||||||
from_logits=True,
|
|
||||||
reduction=tf.keras.losses.Reduction.NONE,
|
|
||||||
)
|
|
||||||
melted_labels = tf.reshape(labels, (-1,))
|
|
||||||
active_loss = tf.not_equal(melted_labels, self.config.pad_token_id)
|
|
||||||
reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
|
|
||||||
labels = tf.boolean_mask(melted_labels, active_loss)
|
|
||||||
return loss_fn(labels, reduced_logits)
|
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ from ...modeling_tf_outputs import (
|
|||||||
# Public API
|
# Public API
|
||||||
from ...modeling_tf_utils import (
|
from ...modeling_tf_utils import (
|
||||||
DUMMY_INPUTS,
|
DUMMY_INPUTS,
|
||||||
|
TFCausalLanguageModelingLoss,
|
||||||
TFPreTrainedModel,
|
TFPreTrainedModel,
|
||||||
TFSharedEmbeddings,
|
TFSharedEmbeddings,
|
||||||
TFWrappedEmbeddings,
|
TFWrappedEmbeddings,
|
||||||
@@ -1239,7 +1240,7 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel):
|
|||||||
"The BLENDERBOT_SMALL Model with a language modeling head. Can be used for summarization.",
|
"The BLENDERBOT_SMALL Model with a language modeling head. Can be used for summarization.",
|
||||||
BLENDERBOT_SMALL_START_DOCSTRING,
|
BLENDERBOT_SMALL_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel):
|
class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||||
_keys_to_ignore_on_load_unexpected = [
|
_keys_to_ignore_on_load_unexpected = [
|
||||||
r"model.encoder.embed_tokens.weight",
|
r"model.encoder.embed_tokens.weight",
|
||||||
r"model.decoder.embed_tokens.weight",
|
r"model.decoder.embed_tokens.weight",
|
||||||
@@ -1327,6 +1328,12 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
|
|||||||
)
|
)
|
||||||
|
|
||||||
if inputs["labels"] is not None:
|
if inputs["labels"] is not None:
|
||||||
|
inputs["labels"] = tf.where(
|
||||||
|
inputs["labels"] == self.config.pad_token_id,
|
||||||
|
tf.fill(shape_list(inputs["labels"]), -100),
|
||||||
|
inputs["labels"],
|
||||||
|
)
|
||||||
|
inputs["use_cache"] = False
|
||||||
if inputs["decoder_input_ids"] is None:
|
if inputs["decoder_input_ids"] is None:
|
||||||
inputs["decoder_input_ids"] = shift_tokens_right(
|
inputs["decoder_input_ids"] = shift_tokens_right(
|
||||||
inputs["labels"], self.config.pad_token_id, self.config.decoder_start_token_id
|
inputs["labels"], self.config.pad_token_id, self.config.decoder_start_token_id
|
||||||
@@ -1452,16 +1459,3 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
|
|||||||
return tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits)
|
return tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits)
|
||||||
else:
|
else:
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
# Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.compute_loss
|
|
||||||
def compute_loss(self, labels, logits):
|
|
||||||
"""CrossEntropyLoss that ignores pad tokens"""
|
|
||||||
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
|
|
||||||
from_logits=True,
|
|
||||||
reduction=tf.keras.losses.Reduction.NONE,
|
|
||||||
)
|
|
||||||
melted_labels = tf.reshape(labels, (-1,))
|
|
||||||
active_loss = tf.not_equal(melted_labels, self.config.pad_token_id)
|
|
||||||
reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
|
|
||||||
labels = tf.boolean_mask(melted_labels, active_loss)
|
|
||||||
return loss_fn(labels, reduced_logits)
|
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ from ...modeling_tf_outputs import (
|
|||||||
# Public API
|
# Public API
|
||||||
from ...modeling_tf_utils import (
|
from ...modeling_tf_utils import (
|
||||||
DUMMY_INPUTS,
|
DUMMY_INPUTS,
|
||||||
|
TFCausalLanguageModelingLoss,
|
||||||
TFPreTrainedModel,
|
TFPreTrainedModel,
|
||||||
TFSharedEmbeddings,
|
TFSharedEmbeddings,
|
||||||
TFWrappedEmbeddings,
|
TFWrappedEmbeddings,
|
||||||
@@ -1256,7 +1257,7 @@ class TFMarianModel(TFMarianPreTrainedModel):
|
|||||||
"The MARIAN Model with a language modeling head. Can be used for summarization.",
|
"The MARIAN Model with a language modeling head. Can be used for summarization.",
|
||||||
MARIAN_START_DOCSTRING,
|
MARIAN_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
class TFMarianMTModel(TFMarianPreTrainedModel):
|
class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||||
_keys_to_ignore_on_load_unexpected = [
|
_keys_to_ignore_on_load_unexpected = [
|
||||||
r"model.encoder.embed_tokens.weight",
|
r"model.encoder.embed_tokens.weight",
|
||||||
r"model.decoder.embed_tokens.weight",
|
r"model.decoder.embed_tokens.weight",
|
||||||
@@ -1344,6 +1345,11 @@ class TFMarianMTModel(TFMarianPreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if inputs["labels"] is not None:
|
if inputs["labels"] is not None:
|
||||||
|
inputs["labels"] = tf.where(
|
||||||
|
inputs["labels"] == self.config.pad_token_id,
|
||||||
|
tf.fill(shape_list(inputs["labels"]), -100),
|
||||||
|
inputs["labels"],
|
||||||
|
)
|
||||||
inputs["use_cache"] = False
|
inputs["use_cache"] = False
|
||||||
if inputs["decoder_input_ids"] is None:
|
if inputs["decoder_input_ids"] is None:
|
||||||
inputs["decoder_input_ids"] = shift_tokens_right(
|
inputs["decoder_input_ids"] = shift_tokens_right(
|
||||||
@@ -1471,16 +1477,3 @@ class TFMarianMTModel(TFMarianPreTrainedModel):
|
|||||||
if cur_len == max_length - 1:
|
if cur_len == max_length - 1:
|
||||||
logits = tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits)
|
logits = tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits)
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
# Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.compute_loss
|
|
||||||
def compute_loss(self, labels, logits):
|
|
||||||
"""CrossEntropyLoss that ignores pad tokens"""
|
|
||||||
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
|
|
||||||
from_logits=True,
|
|
||||||
reduction=tf.keras.losses.Reduction.NONE,
|
|
||||||
)
|
|
||||||
melted_labels = tf.reshape(labels, (-1,))
|
|
||||||
active_loss = tf.not_equal(melted_labels, self.config.pad_token_id)
|
|
||||||
reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
|
|
||||||
labels = tf.boolean_mask(melted_labels, active_loss)
|
|
||||||
return loss_fn(labels, reduced_logits)
|
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ from ...modeling_tf_outputs import (
|
|||||||
# Public API
|
# Public API
|
||||||
from ...modeling_tf_utils import (
|
from ...modeling_tf_utils import (
|
||||||
DUMMY_INPUTS,
|
DUMMY_INPUTS,
|
||||||
|
TFCausalLanguageModelingLoss,
|
||||||
TFPreTrainedModel,
|
TFPreTrainedModel,
|
||||||
TFSharedEmbeddings,
|
TFSharedEmbeddings,
|
||||||
TFWrappedEmbeddings,
|
TFWrappedEmbeddings,
|
||||||
@@ -1257,7 +1258,7 @@ class TFMBartModel(TFMBartPreTrainedModel):
|
|||||||
"The MBART Model with a language modeling head. Can be used for summarization.",
|
"The MBART Model with a language modeling head. Can be used for summarization.",
|
||||||
MBART_START_DOCSTRING,
|
MBART_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
class TFMBartForConditionalGeneration(TFMBartPreTrainedModel):
|
class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||||
_keys_to_ignore_on_load_unexpected = [
|
_keys_to_ignore_on_load_unexpected = [
|
||||||
r"model.encoder.embed_tokens.weight",
|
r"model.encoder.embed_tokens.weight",
|
||||||
r"model.decoder.embed_tokens.weight",
|
r"model.decoder.embed_tokens.weight",
|
||||||
@@ -1345,6 +1346,11 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if inputs["labels"] is not None:
|
if inputs["labels"] is not None:
|
||||||
|
inputs["labels"] = tf.where(
|
||||||
|
inputs["labels"] == self.config.pad_token_id,
|
||||||
|
tf.fill(shape_list(inputs["labels"]), -100),
|
||||||
|
inputs["labels"],
|
||||||
|
)
|
||||||
inputs["use_cache"] = False
|
inputs["use_cache"] = False
|
||||||
if inputs["decoder_input_ids"] is None:
|
if inputs["decoder_input_ids"] is None:
|
||||||
inputs["decoder_input_ids"] = shift_tokens_right(inputs["labels"], self.config.pad_token_id)
|
inputs["decoder_input_ids"] = shift_tokens_right(inputs["labels"], self.config.pad_token_id)
|
||||||
@@ -1469,16 +1475,3 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel):
|
|||||||
return tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits)
|
return tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits)
|
||||||
else:
|
else:
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
# Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.compute_loss
|
|
||||||
def compute_loss(self, labels, logits):
|
|
||||||
"""CrossEntropyLoss that ignores pad tokens"""
|
|
||||||
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
|
|
||||||
from_logits=True,
|
|
||||||
reduction=tf.keras.losses.Reduction.NONE,
|
|
||||||
)
|
|
||||||
melted_labels = tf.reshape(labels, (-1,))
|
|
||||||
active_loss = tf.not_equal(melted_labels, self.config.pad_token_id)
|
|
||||||
reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
|
|
||||||
labels = tf.boolean_mask(melted_labels, active_loss)
|
|
||||||
return loss_fn(labels, reduced_logits)
|
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ from ...modeling_tf_outputs import (
|
|||||||
# Public API
|
# Public API
|
||||||
from ...modeling_tf_utils import (
|
from ...modeling_tf_utils import (
|
||||||
DUMMY_INPUTS,
|
DUMMY_INPUTS,
|
||||||
|
TFCausalLanguageModelingLoss,
|
||||||
TFPreTrainedModel,
|
TFPreTrainedModel,
|
||||||
TFSharedEmbeddings,
|
TFSharedEmbeddings,
|
||||||
TFWrappedEmbeddings,
|
TFWrappedEmbeddings,
|
||||||
@@ -1270,7 +1271,7 @@ class TFPegasusModel(TFPegasusPreTrainedModel):
|
|||||||
"The PEGASUS Model with a language modeling head. Can be used for summarization.",
|
"The PEGASUS Model with a language modeling head. Can be used for summarization.",
|
||||||
PEGASUS_START_DOCSTRING,
|
PEGASUS_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel):
|
class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||||
_keys_to_ignore_on_load_unexpected = [
|
_keys_to_ignore_on_load_unexpected = [
|
||||||
r"model.encoder.embed_tokens.weight",
|
r"model.encoder.embed_tokens.weight",
|
||||||
r"model.decoder.embed_tokens.weight",
|
r"model.decoder.embed_tokens.weight",
|
||||||
@@ -1358,6 +1359,11 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if inputs["labels"] is not None:
|
if inputs["labels"] is not None:
|
||||||
|
inputs["labels"] = tf.where(
|
||||||
|
inputs["labels"] == self.config.pad_token_id,
|
||||||
|
tf.fill(shape_list(inputs["labels"]), -100),
|
||||||
|
inputs["labels"],
|
||||||
|
)
|
||||||
inputs["use_cache"] = False
|
inputs["use_cache"] = False
|
||||||
if inputs["decoder_input_ids"] is None:
|
if inputs["decoder_input_ids"] is None:
|
||||||
inputs["decoder_input_ids"] = shift_tokens_right(
|
inputs["decoder_input_ids"] = shift_tokens_right(
|
||||||
@@ -1484,16 +1490,3 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel):
|
|||||||
return tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits)
|
return tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits)
|
||||||
else:
|
else:
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
# Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.compute_loss
|
|
||||||
def compute_loss(self, labels, logits):
|
|
||||||
"""CrossEntropyLoss that ignores pad tokens"""
|
|
||||||
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
|
|
||||||
from_logits=True,
|
|
||||||
reduction=tf.keras.losses.Reduction.NONE,
|
|
||||||
)
|
|
||||||
melted_labels = tf.reshape(labels, (-1,))
|
|
||||||
active_loss = tf.not_equal(melted_labels, self.config.pad_token_id)
|
|
||||||
reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
|
|
||||||
labels = tf.boolean_mask(melted_labels, active_loss)
|
|
||||||
return loss_fn(labels, reduced_logits)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user