From 74f16b82765a05eccee45e80d79370202a958873 Mon Sep 17 00:00:00 2001 From: Kiyoung Kim Date: Mon, 1 Feb 2021 07:31:29 +0900 Subject: [PATCH] TFBart lables consider both pad token and -100 (#9847) * TFBart lables consider both pad token and -100 * make style * fix for all other models Co-authored-by: kykim Co-authored-by: patrickvonplaten --- .../models/bart/modeling_tf_bart.py | 20 ++++++----------- .../blenderbot/modeling_tf_blenderbot.py | 22 +++++++------------ .../modeling_tf_blenderbot_small.py | 22 +++++++------------ .../models/marian/modeling_tf_marian.py | 21 ++++++------------ .../models/mbart/modeling_tf_mbart.py | 21 ++++++------------ .../models/pegasus/modeling_tf_pegasus.py | 21 ++++++------------ 6 files changed, 44 insertions(+), 83 deletions(-) diff --git a/src/transformers/models/bart/modeling_tf_bart.py b/src/transformers/models/bart/modeling_tf_bart.py index 8c5e641a59..a75fb4ceb5 100644 --- a/src/transformers/models/bart/modeling_tf_bart.py +++ b/src/transformers/models/bart/modeling_tf_bart.py @@ -38,6 +38,7 @@ from ...modeling_tf_outputs import ( # Public API from ...modeling_tf_utils import ( DUMMY_INPUTS, + TFCausalLanguageModelingLoss, TFPreTrainedModel, TFSharedEmbeddings, TFWrappedEmbeddings, @@ -1234,7 +1235,7 @@ class TFBartModel(TFBartPretrainedModel): "The BART Model with a language modeling head. Can be used for summarization.", BART_START_DOCSTRING, ) -class TFBartForConditionalGeneration(TFBartPretrainedModel): +class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageModelingLoss): _keys_to_ignore_on_load_unexpected = [ r"model.encoder.embed_tokens.weight", r"model.decoder.embed_tokens.weight", @@ -1322,6 +1323,11 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel): ) if inputs["labels"] is not None: + inputs["labels"] = tf.where( + inputs["labels"] == self.config.pad_token_id, + tf.fill(shape_list(inputs["labels"]), -100), + inputs["labels"], + ) inputs["use_cache"] = False if inputs["decoder_input_ids"] is None: inputs["decoder_input_ids"] = shift_tokens_right( @@ -1448,15 +1454,3 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel): return tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits) else: return logits - - def compute_loss(self, labels, logits): - """CrossEntropyLoss that ignores pad tokens""" - loss_fn = tf.keras.losses.SparseCategoricalCrossentropy( - from_logits=True, - reduction=tf.keras.losses.Reduction.NONE, - ) - melted_labels = tf.reshape(labels, (-1,)) - active_loss = tf.not_equal(melted_labels, self.config.pad_token_id) - reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss) - labels = tf.boolean_mask(melted_labels, active_loss) - return loss_fn(labels, reduced_logits) diff --git a/src/transformers/models/blenderbot/modeling_tf_blenderbot.py b/src/transformers/models/blenderbot/modeling_tf_blenderbot.py index cffc6095ad..b5c7d80c7f 100644 --- a/src/transformers/models/blenderbot/modeling_tf_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_tf_blenderbot.py @@ -40,6 +40,7 @@ from ...modeling_tf_outputs import ( # Public API from ...modeling_tf_utils import ( DUMMY_INPUTS, + TFCausalLanguageModelingLoss, TFPreTrainedModel, TFSharedEmbeddings, TFWrappedEmbeddings, @@ -1251,7 +1252,7 @@ class TFBlenderbotModel(TFBlenderbotPreTrainedModel): "The BLENDERBOT Model with a language modeling head. Can be used for summarization.", BLENDERBOT_START_DOCSTRING, ) -class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel): +class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausalLanguageModelingLoss): _keys_to_ignore_on_load_unexpected = [ r"model.encoder.embed_tokens.weight", r"model.decoder.embed_tokens.weight", @@ -1352,6 +1353,12 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel): ) if inputs["labels"] is not None: + inputs["labels"] = tf.where( + inputs["labels"] == self.config.pad_token_id, + tf.fill(shape_list(inputs["labels"]), -100), + inputs["labels"], + ) + inputs["use_cache"] = False if inputs["decoder_input_ids"] is None: inputs["decoder_input_ids"] = shift_tokens_right( inputs["labels"], self.config.pad_token_id, self.config.decoder_start_token_id @@ -1477,16 +1484,3 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel): return tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits) else: return logits - - # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.compute_loss - def compute_loss(self, labels, logits): - """CrossEntropyLoss that ignores pad tokens""" - loss_fn = tf.keras.losses.SparseCategoricalCrossentropy( - from_logits=True, - reduction=tf.keras.losses.Reduction.NONE, - ) - melted_labels = tf.reshape(labels, (-1,)) - active_loss = tf.not_equal(melted_labels, self.config.pad_token_id) - reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss) - labels = tf.boolean_mask(melted_labels, active_loss) - return loss_fn(labels, reduced_logits) diff --git a/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py index a1b5d26dbd..beed93b227 100644 --- a/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py @@ -38,6 +38,7 @@ from ...modeling_tf_outputs import ( # Public API from ...modeling_tf_utils import ( DUMMY_INPUTS, + TFCausalLanguageModelingLoss, TFPreTrainedModel, TFSharedEmbeddings, TFWrappedEmbeddings, @@ -1239,7 +1240,7 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel): "The BLENDERBOT_SMALL Model with a language modeling head. Can be used for summarization.", BLENDERBOT_SMALL_START_DOCSTRING, ) -class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel): +class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel, TFCausalLanguageModelingLoss): _keys_to_ignore_on_load_unexpected = [ r"model.encoder.embed_tokens.weight", r"model.decoder.embed_tokens.weight", @@ -1327,6 +1328,12 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel ) if inputs["labels"] is not None: + inputs["labels"] = tf.where( + inputs["labels"] == self.config.pad_token_id, + tf.fill(shape_list(inputs["labels"]), -100), + inputs["labels"], + ) + inputs["use_cache"] = False if inputs["decoder_input_ids"] is None: inputs["decoder_input_ids"] = shift_tokens_right( inputs["labels"], self.config.pad_token_id, self.config.decoder_start_token_id @@ -1452,16 +1459,3 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel return tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits) else: return logits - - # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.compute_loss - def compute_loss(self, labels, logits): - """CrossEntropyLoss that ignores pad tokens""" - loss_fn = tf.keras.losses.SparseCategoricalCrossentropy( - from_logits=True, - reduction=tf.keras.losses.Reduction.NONE, - ) - melted_labels = tf.reshape(labels, (-1,)) - active_loss = tf.not_equal(melted_labels, self.config.pad_token_id) - reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss) - labels = tf.boolean_mask(melted_labels, active_loss) - return loss_fn(labels, reduced_logits) diff --git a/src/transformers/models/marian/modeling_tf_marian.py b/src/transformers/models/marian/modeling_tf_marian.py index 16ec9269ab..dbc4c80016 100644 --- a/src/transformers/models/marian/modeling_tf_marian.py +++ b/src/transformers/models/marian/modeling_tf_marian.py @@ -39,6 +39,7 @@ from ...modeling_tf_outputs import ( # Public API from ...modeling_tf_utils import ( DUMMY_INPUTS, + TFCausalLanguageModelingLoss, TFPreTrainedModel, TFSharedEmbeddings, TFWrappedEmbeddings, @@ -1256,7 +1257,7 @@ class TFMarianModel(TFMarianPreTrainedModel): "The MARIAN Model with a language modeling head. Can be used for summarization.", MARIAN_START_DOCSTRING, ) -class TFMarianMTModel(TFMarianPreTrainedModel): +class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss): _keys_to_ignore_on_load_unexpected = [ r"model.encoder.embed_tokens.weight", r"model.decoder.embed_tokens.weight", @@ -1344,6 +1345,11 @@ class TFMarianMTModel(TFMarianPreTrainedModel): ) if inputs["labels"] is not None: + inputs["labels"] = tf.where( + inputs["labels"] == self.config.pad_token_id, + tf.fill(shape_list(inputs["labels"]), -100), + inputs["labels"], + ) inputs["use_cache"] = False if inputs["decoder_input_ids"] is None: inputs["decoder_input_ids"] = shift_tokens_right( @@ -1471,16 +1477,3 @@ class TFMarianMTModel(TFMarianPreTrainedModel): if cur_len == max_length - 1: logits = tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits) return logits - - # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.compute_loss - def compute_loss(self, labels, logits): - """CrossEntropyLoss that ignores pad tokens""" - loss_fn = tf.keras.losses.SparseCategoricalCrossentropy( - from_logits=True, - reduction=tf.keras.losses.Reduction.NONE, - ) - melted_labels = tf.reshape(labels, (-1,)) - active_loss = tf.not_equal(melted_labels, self.config.pad_token_id) - reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss) - labels = tf.boolean_mask(melted_labels, active_loss) - return loss_fn(labels, reduced_logits) diff --git a/src/transformers/models/mbart/modeling_tf_mbart.py b/src/transformers/models/mbart/modeling_tf_mbart.py index 8562b0eb17..71ea3d0877 100644 --- a/src/transformers/models/mbart/modeling_tf_mbart.py +++ b/src/transformers/models/mbart/modeling_tf_mbart.py @@ -38,6 +38,7 @@ from ...modeling_tf_outputs import ( # Public API from ...modeling_tf_utils import ( DUMMY_INPUTS, + TFCausalLanguageModelingLoss, TFPreTrainedModel, TFSharedEmbeddings, TFWrappedEmbeddings, @@ -1257,7 +1258,7 @@ class TFMBartModel(TFMBartPreTrainedModel): "The MBART Model with a language modeling head. Can be used for summarization.", MBART_START_DOCSTRING, ) -class TFMBartForConditionalGeneration(TFMBartPreTrainedModel): +class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageModelingLoss): _keys_to_ignore_on_load_unexpected = [ r"model.encoder.embed_tokens.weight", r"model.decoder.embed_tokens.weight", @@ -1345,6 +1346,11 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel): ) if inputs["labels"] is not None: + inputs["labels"] = tf.where( + inputs["labels"] == self.config.pad_token_id, + tf.fill(shape_list(inputs["labels"]), -100), + inputs["labels"], + ) inputs["use_cache"] = False if inputs["decoder_input_ids"] is None: inputs["decoder_input_ids"] = shift_tokens_right(inputs["labels"], self.config.pad_token_id) @@ -1469,16 +1475,3 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel): return tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits) else: return logits - - # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.compute_loss - def compute_loss(self, labels, logits): - """CrossEntropyLoss that ignores pad tokens""" - loss_fn = tf.keras.losses.SparseCategoricalCrossentropy( - from_logits=True, - reduction=tf.keras.losses.Reduction.NONE, - ) - melted_labels = tf.reshape(labels, (-1,)) - active_loss = tf.not_equal(melted_labels, self.config.pad_token_id) - reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss) - labels = tf.boolean_mask(melted_labels, active_loss) - return loss_fn(labels, reduced_logits) diff --git a/src/transformers/models/pegasus/modeling_tf_pegasus.py b/src/transformers/models/pegasus/modeling_tf_pegasus.py index a218076d76..57908e223c 100644 --- a/src/transformers/models/pegasus/modeling_tf_pegasus.py +++ b/src/transformers/models/pegasus/modeling_tf_pegasus.py @@ -39,6 +39,7 @@ from ...modeling_tf_outputs import ( # Public API from ...modeling_tf_utils import ( DUMMY_INPUTS, + TFCausalLanguageModelingLoss, TFPreTrainedModel, TFSharedEmbeddings, TFWrappedEmbeddings, @@ -1270,7 +1271,7 @@ class TFPegasusModel(TFPegasusPreTrainedModel): "The PEGASUS Model with a language modeling head. Can be used for summarization.", PEGASUS_START_DOCSTRING, ) -class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel): +class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLanguageModelingLoss): _keys_to_ignore_on_load_unexpected = [ r"model.encoder.embed_tokens.weight", r"model.decoder.embed_tokens.weight", @@ -1358,6 +1359,11 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel): ) if inputs["labels"] is not None: + inputs["labels"] = tf.where( + inputs["labels"] == self.config.pad_token_id, + tf.fill(shape_list(inputs["labels"]), -100), + inputs["labels"], + ) inputs["use_cache"] = False if inputs["decoder_input_ids"] is None: inputs["decoder_input_ids"] = shift_tokens_right( @@ -1484,16 +1490,3 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel): return tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits) else: return logits - - # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.compute_loss - def compute_loss(self, labels, logits): - """CrossEntropyLoss that ignores pad tokens""" - loss_fn = tf.keras.losses.SparseCategoricalCrossentropy( - from_logits=True, - reduction=tf.keras.losses.Reduction.NONE, - ) - melted_labels = tf.reshape(labels, (-1,)) - active_loss = tf.not_equal(melted_labels, self.config.pad_token_id) - reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss) - labels = tf.boolean_mask(melted_labels, active_loss) - return loss_fn(labels, reduced_logits)