From 7a9ef8181c9bb92385eccbb2d35e864fa80fadf2 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 29 Mar 2022 18:17:47 +0100 Subject: [PATCH] TF: properly handle kwargs in encoder_decoder architectures (#16465) * properly handle kwargs in encoder_decoder architectures * make fixup --- .../models/encoder_decoder/modeling_tf_encoder_decoder.py | 6 ++---- .../modeling_tf_vision_encoder_decoder.py | 6 ++---- tests/encoder_decoder/test_modeling_tf_encoder_decoder.py | 8 ++++++++ .../test_modeling_tf_vision_encoder_decoder.py | 8 ++++++++ 4 files changed, 20 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py index c2be91c7a0..1c59493e1b 100644 --- a/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py @@ -569,13 +569,12 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss): "output_hidden_states": output_hidden_states, "return_dict": return_dict, "training": training, - "kwargs_call": kwargs_encoder, + "kwargs_call": {}, } # Add arguments to encoder from `kwargs_encoder` for k, v in kwargs_encoder.items(): encoder_processing_inputs[k] = v - kwargs_encoder = {} encoder_inputs = input_processing(**encoder_processing_inputs) @@ -622,13 +621,12 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss): "past_key_values": past_key_values, "return_dict": return_dict, "training": training, - "kwargs_call": kwargs_decoder, + "kwargs_call": {}, } # Add arguments to decoder from `kwargs_decoder` for k, v in kwargs_decoder.items(): decoder_processing_inputs[k] = v - kwargs_decoder = {} decoder_inputs = input_processing(**decoder_processing_inputs) decoder_outputs = self.decoder(**decoder_inputs) diff --git a/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py index 965fc51d78..eeaca58c5a 100644 --- a/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py +++ b/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py @@ -593,12 +593,11 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos "output_hidden_states": output_hidden_states, "return_dict": return_dict, "training": training, - "kwargs_call": kwargs_encoder, + "kwargs_call": {}, } # Add arguments to encoder from `kwargs_encoder` encoder_processing_inputs.update(kwargs_encoder) - kwargs_encoder = {} encoder_inputs = input_processing(**encoder_processing_inputs) @@ -654,12 +653,11 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos "past_key_values": past_key_values, "return_dict": return_dict, "training": training, - "kwargs_call": kwargs_decoder, + "kwargs_call": {}, } # Add arguments to decoder from `kwargs_decoder` decoder_processing_inputs.update(kwargs_decoder) - kwargs_decoder = {} decoder_inputs = input_processing(**decoder_processing_inputs) decoder_outputs = self.decoder(**decoder_inputs) diff --git a/tests/encoder_decoder/test_modeling_tf_encoder_decoder.py b/tests/encoder_decoder/test_modeling_tf_encoder_decoder.py index edcc881f56..de903c40c2 100644 --- a/tests/encoder_decoder/test_modeling_tf_encoder_decoder.py +++ b/tests/encoder_decoder/test_modeling_tf_encoder_decoder.py @@ -91,6 +91,7 @@ class TFEncoderDecoderMixin: decoder_input_ids=decoder_input_ids, attention_mask=attention_mask, decoder_attention_mask=decoder_attention_mask, + kwargs=kwargs, ) self.assertEqual( @@ -122,6 +123,7 @@ class TFEncoderDecoderMixin: decoder_input_ids=decoder_input_ids, attention_mask=attention_mask, decoder_attention_mask=decoder_attention_mask, + kwargs=kwargs, ) self.assertEqual( outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,)) @@ -137,6 +139,7 @@ class TFEncoderDecoderMixin: decoder_input_ids=decoder_input_ids, attention_mask=attention_mask, decoder_attention_mask=decoder_attention_mask, + kwargs=kwargs, ) self.assertEqual( @@ -167,6 +170,7 @@ class TFEncoderDecoderMixin: attention_mask=attention_mask, decoder_attention_mask=decoder_attention_mask, return_dict=True, + kwargs=kwargs, ) self.assertEqual( @@ -195,6 +199,7 @@ class TFEncoderDecoderMixin: decoder_input_ids=decoder_input_ids, attention_mask=attention_mask, decoder_attention_mask=decoder_attention_mask, + kwargs=kwargs, ) out_2 = np.array(outputs[0]) out_2[np.isnan(out_2)] = 0 @@ -208,6 +213,7 @@ class TFEncoderDecoderMixin: decoder_input_ids=decoder_input_ids, attention_mask=attention_mask, decoder_attention_mask=decoder_attention_mask, + kwargs=kwargs, ) out_1 = np.array(after_outputs[0]) out_1[np.isnan(out_1)] = 0 @@ -235,6 +241,7 @@ class TFEncoderDecoderMixin: attention_mask=attention_mask, decoder_attention_mask=decoder_attention_mask, labels=labels, + kwargs=kwargs, ) # Make sure `loss` exist @@ -269,6 +276,7 @@ class TFEncoderDecoderMixin: attention_mask=attention_mask, decoder_attention_mask=decoder_attention_mask, output_attentions=True, + kwargs=kwargs, ) encoder_attentions = outputs_encoder_decoder["encoder_attentions"] diff --git a/tests/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py b/tests/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py index a0fcbfaea3..f3a062744f 100644 --- a/tests/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py +++ b/tests/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py @@ -96,6 +96,7 @@ class TFVisionEncoderDecoderMixin: pixel_values=pixel_values, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, + kwargs=kwargs, ) self.assertEqual( @@ -124,6 +125,7 @@ class TFVisionEncoderDecoderMixin: pixel_values=pixel_values, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, + kwargs=kwargs, ) self.assertEqual( outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,)) @@ -137,6 +139,7 @@ class TFVisionEncoderDecoderMixin: encoder_outputs=encoder_outputs, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, + kwargs=kwargs, ) self.assertEqual( @@ -164,6 +167,7 @@ class TFVisionEncoderDecoderMixin: decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, return_dict=True, + kwargs=kwargs, ) self.assertEqual( @@ -189,6 +193,7 @@ class TFVisionEncoderDecoderMixin: pixel_values=pixel_values, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, + kwargs=kwargs, ) out_2 = np.array(outputs[0]) out_2[np.isnan(out_2)] = 0 @@ -201,6 +206,7 @@ class TFVisionEncoderDecoderMixin: pixel_values=pixel_values, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, + kwargs=kwargs, ) out_1 = np.array(after_outputs[0]) out_1[np.isnan(out_1)] = 0 @@ -226,6 +232,7 @@ class TFVisionEncoderDecoderMixin: decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, labels=labels, + kwargs=kwargs, ) # Make sure `loss` exist @@ -257,6 +264,7 @@ class TFVisionEncoderDecoderMixin: decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, output_attentions=True, + kwargs=kwargs, ) encoder_attentions = outputs_encoder_decoder["encoder_attentions"]