From 724e51c6e605639995e6cc0c6f7de99a749ba868 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Thu, 10 Feb 2022 15:47:02 +0100 Subject: [PATCH] Compute loss independent from decoder for TF EncDec models (as #14139) (#15175) * Compute loss independent from decoder (as 14139) * fix expected seq_len + style * Apply the same change to TFVisionEncoderDecoderModel * fix style * Add case with labels in equivalence test * uncomment * Add case with labels in equivalence test * add decoder_token_labels * use hf_compute_loss * Apply suggestions from code review Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Add copied from Co-authored-by: ydshieh Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> --- .../modeling_tf_encoder_decoder.py | 70 +++++++++++++++--- .../modeling_tf_vision_encoder_decoder.py | 71 ++++++++++++++++--- tests/test_modeling_tf_encoder_decoder.py | 26 ++++++- ...test_modeling_tf_vision_encoder_decoder.py | 25 ++++++- 4 files changed, 171 insertions(+), 21 deletions(-) diff --git a/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py index 63ce2c87ac..8ba4ae31b8 100644 --- a/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py @@ -16,6 +16,7 @@ import tempfile +import warnings from typing import Optional import tensorflow as tf @@ -29,7 +30,13 @@ from ...file_utils import ( replace_return_docstrings, ) from ...modeling_tf_outputs import TFBaseModelOutput, TFSeq2SeqLMOutput -from ...modeling_tf_utils import TFPreTrainedModel, get_initializer, input_processing +from ...modeling_tf_utils import ( + TFCausalLanguageModelingLoss, + TFPreTrainedModel, + get_initializer, + input_processing, + shape_list, +) from ...utils import logging from ..auto.configuration_auto import AutoConfig from ..auto.modeling_tf_auto import TFAutoModel, TFAutoModelForCausalLM @@ -40,6 +47,13 @@ logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "EncoderDecoderConfig" +DEPRECATION_WARNING = ( + "Version v4.17.0 introduces a better way to train encoder-decoder models by computing the loss inside the " + "encoder-decoder framework rather than in the decoder itself. You may observe training discrepancies if fine-tuning " + "a model trained with versions anterior to 4.17.0. The decoder_input_ids are now created based on the labels, no " + "need to pass them yourself anymore." +) + ENCODER_DECODER_START_DOCSTRING = r""" This class can be used to initialize a sequence-to-sequence model with any pretrained autoencoding model as the encoder and any pretrained autoregressive model as the decoder. The encoder is loaded via @@ -145,8 +159,36 @@ ENCODER_DECODER_INPUTS_DOCSTRING = r""" """ +def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int): + + if pad_token_id is None: + raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.") + pad_token_id = tf.cast(pad_token_id, input_ids.dtype) + + if decoder_start_token_id is None: + raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.") + decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype) + + start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id) + shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1) + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids = tf.where( + shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids + ) + + if tf.executing_eagerly(): + # "Verify that `labels` has only positive values and -100" + assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype)) + + # Make sure the assertion op is called by wrapping the result in an identity no-op + with tf.control_dependencies([assert_gte0]): + shifted_input_ids = tf.identity(shifted_input_ids) + + return shifted_input_ids + + @add_start_docstrings(ENCODER_DECODER_START_DOCSTRING) -class TFEncoderDecoderModel(TFPreTrainedModel): +class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss): r""" [`TFEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture with one of the base model classes of the library as encoder and another one as decoder when created with the @@ -566,6 +608,11 @@ class TFEncoderDecoderModel(TFPreTrainedModel): ): encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states) + if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None): + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + decoder_processing_inputs = { "func": self.decoder.call, "config": self.decoder.config, @@ -574,7 +621,6 @@ class TFEncoderDecoderModel(TFPreTrainedModel): "encoder_hidden_states": encoder_hidden_states, "encoder_attention_mask": attention_mask, "inputs_embeds": decoder_inputs_embeds, - "labels": labels, "output_attentions": output_attentions, "output_hidden_states": output_hidden_states, "use_cache": use_cache, @@ -592,12 +638,17 @@ class TFEncoderDecoderModel(TFPreTrainedModel): decoder_inputs = input_processing(**decoder_processing_inputs) decoder_outputs = self.decoder(**decoder_inputs) - loss = None if decoder_inputs["labels"] is None else decoder_outputs[0] - logits = decoder_outputs[0] if decoder_inputs["labels"] is None else decoder_outputs[1] - past_key_values = None + logits = decoder_outputs[0] + # Compute loss independent from decoder (as some shift the logits inside them) + loss = None + if labels is not None: + warnings.warn(DEPRECATION_WARNING, FutureWarning) + loss = self.hf_compute_loss(labels, logits) + + past_key_values = None if decoder_inputs["use_cache"]: - past_key_values = decoder_outputs[1] if decoder_inputs["labels"] is None else decoder_outputs[2] + past_key_values = decoder_outputs[1] # The starting index of the remaining elements in `decoder_outputs` start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)]) @@ -611,7 +662,7 @@ class TFEncoderDecoderModel(TFPreTrainedModel): return output return TFSeq2SeqLMOutput( - loss=decoder_outputs.loss, + loss=loss, logits=decoder_outputs.logits, past_key_values=past, decoder_hidden_states=decoder_outputs.hidden_states, @@ -693,6 +744,9 @@ class TFEncoderDecoderModel(TFPreTrainedModel): "use_cache": use_cache, # change this to avoid caching (presumably for debugging) } + def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor): + return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) + def resize_token_embeddings(self, *args, **kwargs): raise NotImplementedError( "Resizing the embedding layers via the TFEncoderDecoderModel directly is not supported." diff --git a/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py index f02c595eed..06bcbf7c4b 100644 --- a/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py +++ b/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py @@ -16,6 +16,7 @@ import tempfile +import warnings from typing import Optional import tensorflow as tf @@ -29,7 +30,13 @@ from ...file_utils import ( replace_return_docstrings, ) from ...modeling_tf_outputs import TFBaseModelOutput, TFSeq2SeqLMOutput -from ...modeling_tf_utils import TFPreTrainedModel, get_initializer, input_processing, shape_list +from ...modeling_tf_utils import ( + TFCausalLanguageModelingLoss, + TFPreTrainedModel, + get_initializer, + input_processing, + shape_list, +) from ...utils import logging from ..auto.configuration_auto import AutoConfig from ..auto.modeling_tf_auto import TFAutoModel, TFAutoModelForCausalLM @@ -40,6 +47,13 @@ logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "VisionEncoderDecoderConfig" +DEPRECATION_WARNING = ( + "Version v4.17.0 introduces a better way to train encoder-decoder models by computing the loss inside the " + "encoder-decoder framework rather than in the decoder itself. You may observe training discrepancies if fine-tuning " + "a model trained with versions anterior to 4.17.0. The decoder_input_ids are now created based on the labels, no " + "need to pass them yourself anymore." +) + VISION_ENCODER_DECODER_START_DOCSTRING = r""" This class can be used to initialize an image-to-text-sequence model with any pretrained vision autoencoding model as the encoder and any pretrained text autoregressive model as the decoder. The encoder is loaded via @@ -134,8 +148,37 @@ VISION_ENCODER_DECODER_INPUTS_DOCSTRING = r""" """ +# Copied from transformers.models.encoder_decoder.modeling_tf_encoder_decoder.shift_tokens_right +def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int): + + if pad_token_id is None: + raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.") + pad_token_id = tf.cast(pad_token_id, input_ids.dtype) + + if decoder_start_token_id is None: + raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.") + decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype) + + start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id) + shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1) + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids = tf.where( + shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids + ) + + if tf.executing_eagerly(): + # "Verify that `labels` has only positive values and -100" + assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype)) + + # Make sure the assertion op is called by wrapping the result in an identity no-op + with tf.control_dependencies([assert_gte0]): + shifted_input_ids = tf.identity(shifted_input_ids) + + return shifted_input_ids + + @add_start_docstrings(VISION_ENCODER_DECODER_START_DOCSTRING) -class TFVisionEncoderDecoderModel(TFPreTrainedModel): +class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss): r""" [`TFVisionEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture with one of the base vision model classes of the library as encoder and another one of the base model classes as @@ -594,6 +637,11 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel): ): encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states) + if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None): + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + batch_size, sequence_length = shape_list(encoder_hidden_states)[:2] encoder_attention_mask = tf.ones(shape=(batch_size, sequence_length), dtype=tf.int32) @@ -605,7 +653,6 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel): "encoder_hidden_states": encoder_hidden_states, "encoder_attention_mask": encoder_attention_mask, "inputs_embeds": decoder_inputs_embeds, - "labels": labels, "output_attentions": output_attentions, "output_hidden_states": output_hidden_states, "use_cache": use_cache, @@ -622,12 +669,17 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel): decoder_inputs = input_processing(**decoder_processing_inputs) decoder_outputs = self.decoder(**decoder_inputs) - loss = None if decoder_inputs["labels"] is None else decoder_outputs[0] - logits = decoder_outputs[0] if decoder_inputs["labels"] is None else decoder_outputs[1] - past_key_values = None + logits = decoder_outputs[0] + # Compute loss independent from decoder (as some shift the logits inside them) + loss = None + if labels is not None: + warnings.warn(DEPRECATION_WARNING, FutureWarning) + loss = self.hf_compute_loss(labels, logits) + + past_key_values = None if decoder_inputs["use_cache"]: - past_key_values = decoder_outputs[1] if decoder_inputs["labels"] is None else decoder_outputs[2] + past_key_values = decoder_outputs[1] # The starting index of the remaining elements in `decoder_outputs` start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)]) @@ -641,7 +693,7 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel): return output return TFSeq2SeqLMOutput( - loss=decoder_outputs.loss, + loss=loss, logits=decoder_outputs.logits, past_key_values=past, decoder_hidden_states=decoder_outputs.hidden_states, @@ -715,6 +767,9 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel): "use_cache": use_cache, # change this to avoid caching (presumably for debugging) } + def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor): + return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) + def resize_token_embeddings(self, *args, **kwargs): raise NotImplementedError( "Resizing the embedding layers via the TFVisionEncoderDecoderModel directly is not supported." diff --git a/tests/test_modeling_tf_encoder_decoder.py b/tests/test_modeling_tf_encoder_decoder.py index d5b0a9dd61..3655348ab5 100644 --- a/tests/test_modeling_tf_encoder_decoder.py +++ b/tests/test_modeling_tf_encoder_decoder.py @@ -14,6 +14,7 @@ # limitations under the License. +import copy import os import tempfile import unittest @@ -237,7 +238,7 @@ class TFEncoderDecoderMixin: ) # Make sure `loss` exist - assert "loss" in outputs_encoder_decoder + self.assertIn("loss", outputs_encoder_decoder) batch_size, seq_len = decoder_input_ids.shape expected_shape = (batch_size, seq_len, decoder_config.vocab_size) @@ -319,12 +320,18 @@ class TFEncoderDecoderMixin: # prepare inputs tf_inputs = inputs_dict pt_inputs = {k: torch.tensor(v.numpy()) for k, v in tf_inputs.items()} + if "labels" in pt_inputs: + pt_inputs["labels"] = pt_inputs["labels"].type(torch.LongTensor) with torch.no_grad(): pt_outputs = pt_model(**pt_inputs).to_tuple() - tf_outputs = tf_model(**inputs_dict).to_tuple() + tf_outputs = tf_model(**inputs_dict) + if "loss" in tf_outputs: + tf_outputs.loss = tf.math.reduce_mean(tf_outputs.loss) + tf_outputs = tf_outputs.to_tuple() self.assertEqual(len(tf_outputs), len(pt_outputs), "Output lengths differ between TF and PyTorch") + for tf_output, pt_output in zip(tf_outputs, pt_outputs): self.assert_almost_equals(tf_output.numpy(), pt_output.numpy(), 1e-3) @@ -339,8 +346,12 @@ class TFEncoderDecoderMixin: # This is only for copying some specific attributes of this particular model. tf_model_loaded.config = pt_model.config - tf_outputs_loaded = tf_model_loaded(**inputs_dict).to_tuple() + tf_outputs_loaded = tf_model_loaded(**inputs_dict) + if "loss" in tf_outputs_loaded: + tf_outputs_loaded.loss = tf.math.reduce_mean(tf_outputs_loaded.loss) + tf_outputs_loaded = tf_outputs_loaded.to_tuple() self.assertEqual(len(tf_outputs_loaded), len(pt_outputs), "Output lengths differ between TF and PyTorch") + for tf_output_loaded, pt_output in zip(tf_outputs_loaded, pt_outputs): self.assert_almost_equals(tf_output_loaded.numpy(), pt_output.numpy(), 1e-3) @@ -435,6 +446,8 @@ class TFEncoderDecoderMixin: def test_pt_tf_equivalence(self): config_inputs_dict = self.prepare_config_and_inputs() + labels = config_inputs_dict.pop("decoder_token_labels") + # Keep only common arguments arg_names = [ "config", @@ -454,6 +467,9 @@ class TFEncoderDecoderMixin: # `encoder_hidden_states` is not used in model call/forward del inputs_dict["encoder_hidden_states"] + inputs_dict_with_labels = copy.copy(inputs_dict) + inputs_dict_with_labels["labels"] = labels + # Avoid the case where a sequence has no place to attend (after combined with the causal attention mask) batch_size = inputs_dict["decoder_attention_mask"].shape[0] inputs_dict["decoder_attention_mask"] = tf.constant( @@ -471,6 +487,10 @@ class TFEncoderDecoderMixin: self.check_equivalence_pt_to_tf(config, decoder_config, inputs_dict) self.check_equivalence_tf_to_pt(config, decoder_config, inputs_dict) + # check equivalence with labels + self.check_equivalence_pt_to_tf(config, decoder_config, inputs_dict_with_labels) + self.check_equivalence_tf_to_pt(config, decoder_config, inputs_dict_with_labels) + # This is not working, because pt/tf equivalence test for encoder-decoder use `from_encoder_decoder_pretrained`, # which randomly initialize `enc_to_dec_proj`. # # check `enc_to_dec_proj` work as expected diff --git a/tests/test_modeling_tf_vision_encoder_decoder.py b/tests/test_modeling_tf_vision_encoder_decoder.py index ec041786a8..3f1783e034 100644 --- a/tests/test_modeling_tf_vision_encoder_decoder.py +++ b/tests/test_modeling_tf_vision_encoder_decoder.py @@ -15,6 +15,7 @@ """ Testing suite for the TensorFlow VisionEncoderDecoder model. """ +import copy import os import tempfile import unittest @@ -307,12 +308,18 @@ class TFVisionEncoderDecoderMixin: # prepare inputs tf_inputs = inputs_dict pt_inputs = {k: torch.tensor(v.numpy()) for k, v in tf_inputs.items()} + if "labels" in pt_inputs: + pt_inputs["labels"] = pt_inputs["labels"].type(torch.LongTensor) with torch.no_grad(): pt_outputs = pt_model(**pt_inputs).to_tuple() - tf_outputs = tf_model(**inputs_dict).to_tuple() + tf_outputs = tf_model(**inputs_dict) + if "loss" in tf_outputs: + tf_outputs.loss = tf.math.reduce_mean(tf_outputs.loss) + tf_outputs = tf_outputs.to_tuple() self.assertEqual(len(tf_outputs), len(pt_outputs), "Output lengths differ between TF and PyTorch") + for tf_output, pt_output in zip(tf_outputs, pt_outputs): self.assert_almost_equals(tf_output.numpy(), pt_output.numpy(), 1e-3) @@ -327,8 +334,12 @@ class TFVisionEncoderDecoderMixin: # This is only for copying some specific attributes of this particular model. tf_model_loaded.config = pt_model.config - tf_outputs_loaded = tf_model_loaded(**inputs_dict).to_tuple() + tf_outputs_loaded = tf_model_loaded(**inputs_dict) + if "loss" in tf_outputs_loaded: + tf_outputs_loaded.loss = tf.math.reduce_mean(tf_outputs_loaded.loss) + tf_outputs_loaded = tf_outputs_loaded.to_tuple() self.assertEqual(len(tf_outputs_loaded), len(pt_outputs), "Output lengths differ between TF and PyTorch") + for tf_output_loaded, pt_output in zip(tf_outputs_loaded, pt_outputs): self.assert_almost_equals(tf_output_loaded.numpy(), pt_output.numpy(), 1e-3) @@ -423,6 +434,8 @@ class TFVisionEncoderDecoderMixin: def test_pt_tf_equivalence(self): config_inputs_dict = self.prepare_config_and_inputs() + labels = config_inputs_dict.pop("decoder_token_labels") + # Keep only common arguments arg_names = [ "config", @@ -441,6 +454,9 @@ class TFVisionEncoderDecoderMixin: # `encoder_hidden_states` is not used in model call/forward del inputs_dict["encoder_hidden_states"] + inputs_dict_with_labels = copy.copy(inputs_dict) + inputs_dict_with_labels["labels"] = labels + # Avoid the case where a sequence has no place to attend (after combined with the causal attention mask) batch_size = inputs_dict["decoder_attention_mask"].shape[0] inputs_dict["decoder_attention_mask"] = tf.constant( @@ -458,6 +474,10 @@ class TFVisionEncoderDecoderMixin: self.check_equivalence_pt_to_tf(config, decoder_config, inputs_dict) self.check_equivalence_tf_to_pt(config, decoder_config, inputs_dict) + # check equivalence with labels + self.check_equivalence_pt_to_tf(config, decoder_config, inputs_dict_with_labels) + self.check_equivalence_tf_to_pt(config, decoder_config, inputs_dict_with_labels) + # This is not working, because pt/tf equivalence test for encoder-decoder use `from_encoder_decoder_pretrained`, # which randomly initialize `enc_to_dec_proj`. # # check `enc_to_dec_proj` work as expected @@ -543,6 +563,7 @@ class TFViT2GPT2EncoderDecoderModelTest(TFVisionEncoderDecoderMixin, unittest.Te "decoder_config": decoder_config, "decoder_input_ids": decoder_input_ids, "decoder_attention_mask": decoder_attention_mask, + "decoder_token_labels": decoder_token_labels, "encoder_hidden_states": encoder_hidden_states, # This is not used in the tests. "labels": decoder_token_labels, }