From 1fc4b2a13223b9069f9969344117a2994261939c Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 22 Jul 2022 13:31:45 +0100 Subject: [PATCH] TF: use the correct config with `(...)EncoderDecoder` models (#18097) --- src/transformers/modeling_tf_utils.py | 30 +++-- .../modeling_tf_encoder_decoder.py | 14 +-- .../modeling_tf_vision_encoder_decoder.py | 14 +-- .../test_modeling_encoder_decoder.py | 110 ++++++++++++++---- .../test_modeling_tf_encoder_decoder.py | 107 +++++++++++++---- 5 files changed, 200 insertions(+), 75 deletions(-) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 1806e43655..ddb24d7e3b 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -403,8 +403,13 @@ def unpack_inputs(func): # move any arg into kwargs, if they exist fn_args_and_kwargs.update(dict(zip(func.__code__.co_varnames[1:], args))) - # process the inputs and call the wrapped function - unpacked_inputs = input_processing(func, self.config, **fn_args_and_kwargs) + # Encoder Decoder models delegate the application of the configuration options to their inner models. + if "encoder_decoder" in str(self).lower(): + config = None + else: + config = self.config + + unpacked_inputs = input_processing(func, config, **fn_args_and_kwargs) return func(self, **unpacked_inputs) # Keras enforces the first layer argument to be passed, and checks it through `inspect.getfullargspec()`. This @@ -559,18 +564,19 @@ def input_processing(func, config, **kwargs): if "kwargs" in output: del output["kwargs"] - boolean_dict = { - k: v - for k, v in output.items() - if k in ["return_dict", "output_attentions", "output_hidden_states", "use_cache"] - } + if config is not None: + boolean_dict = { + k: v + for k, v in output.items() + if k in ["return_dict", "output_attentions", "output_hidden_states", "use_cache"] + } - output.update( - booleans_processing( - config=config, - **boolean_dict, + output.update( + booleans_processing( + config=config, + **boolean_dict, + ) ) - ) return output diff --git a/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py index 5c74e8433e..714e2c231d 100644 --- a/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py @@ -630,13 +630,13 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss): warnings.warn(DEPRECATION_WARNING, FutureWarning) loss = self.hf_compute_loss(labels, logits) - past_key_values = None - if decoder_inputs["use_cache"]: - past_key_values = decoder_outputs[1] - # The starting index of the remaining elements in `decoder_outputs` - start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)]) + if not return_dict: + past_key_values = None + if use_cache: + past_key_values = decoder_outputs[1] + # The starting index of the remaining elements in `decoder_outputs` + start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)]) - if not decoder_inputs["return_dict"]: if not isinstance(encoder_outputs, tuple): encoder_outputs = encoder_outputs.to_tuple() output = (loss, logits, past_key_values) + decoder_outputs[start_index:] + encoder_outputs @@ -646,7 +646,7 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss): return TFSeq2SeqLMOutput( loss=loss, logits=decoder_outputs.logits, - past_key_values=past_key_values, + past_key_values=decoder_outputs.past_key_values, decoder_hidden_states=decoder_outputs.hidden_states, decoder_attentions=decoder_outputs.attentions, cross_attentions=decoder_outputs.cross_attentions, diff --git a/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py index ba65525ae0..682faa3825 100644 --- a/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py +++ b/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py @@ -663,13 +663,13 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos warnings.warn(DEPRECATION_WARNING, FutureWarning) loss = self.hf_compute_loss(labels, logits) - past_key_values = None - if decoder_inputs["use_cache"]: - past_key_values = decoder_outputs[1] - # The starting index of the remaining elements in `decoder_outputs` - start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)]) + if not return_dict: + past_key_values = None + if use_cache: + past_key_values = decoder_outputs[1] + # The starting index of the remaining elements in `decoder_outputs` + start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)]) - if not decoder_inputs["return_dict"]: if not isinstance(encoder_outputs, tuple): encoder_outputs = encoder_outputs.to_tuple() output = (loss, logits, past_key_values) + decoder_outputs[start_index:] + encoder_outputs @@ -679,7 +679,7 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos return TFSeq2SeqLMOutput( loss=loss, logits=decoder_outputs.logits, - past_key_values=past_key_values, + past_key_values=decoder_outputs.past_key_values, decoder_hidden_states=decoder_outputs.hidden_states, decoder_attentions=decoder_outputs.attentions, cross_attentions=decoder_outputs.cross_attentions, diff --git a/tests/models/encoder_decoder/test_modeling_encoder_decoder.py b/tests/models/encoder_decoder/test_modeling_encoder_decoder.py index b356b3ee0b..6980ed6cb2 100644 --- a/tests/models/encoder_decoder/test_modeling_encoder_decoder.py +++ b/tests/models/encoder_decoder/test_modeling_encoder_decoder.py @@ -351,32 +351,9 @@ class EncoderDecoderMixin: outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,)) ) - def check_encoder_decoder_model_output_attentions( - self, - config, - input_ids, - attention_mask, - encoder_hidden_states, - decoder_config, - decoder_input_ids, - decoder_attention_mask, - labels, - **kwargs + def _check_output_with_attentions( + self, outputs_encoder_decoder, config, input_ids, decoder_config, decoder_input_ids ): - # make the decoder inputs a different shape from the encoder inputs to harden the test - decoder_input_ids = decoder_input_ids[:, :-1] - decoder_attention_mask = decoder_attention_mask[:, :-1] - encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config) - enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model) - enc_dec_model.to(torch_device) - outputs_encoder_decoder = enc_dec_model( - input_ids=input_ids, - decoder_input_ids=decoder_input_ids, - attention_mask=attention_mask, - decoder_attention_mask=decoder_attention_mask, - output_attentions=True, - ) - encoder_attentions = outputs_encoder_decoder["encoder_attentions"] self.assertEqual(len(encoder_attentions), config.num_hidden_layers) @@ -408,6 +385,85 @@ class EncoderDecoderMixin: (decoder_config.num_attention_heads, cross_attention_input_seq_len, input_ids.shape[-1]), ) + def check_encoder_decoder_model_output_attentions( + self, + config, + input_ids, + attention_mask, + encoder_hidden_states, + decoder_config, + decoder_input_ids, + decoder_attention_mask, + labels, + **kwargs + ): + # make the decoder inputs a different shape from the encoder inputs to harden the test + decoder_input_ids = decoder_input_ids[:, :-1] + decoder_attention_mask = decoder_attention_mask[:, :-1] + encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config) + enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model) + enc_dec_model.to(torch_device) + outputs_encoder_decoder = enc_dec_model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + output_attentions=True, + ) + self._check_output_with_attentions( + outputs_encoder_decoder, config, input_ids, decoder_config, decoder_input_ids + ) + + def check_encoder_decoder_model_output_attentions_from_config( + self, + config, + input_ids, + attention_mask, + encoder_hidden_states, + decoder_config, + decoder_input_ids, + decoder_attention_mask, + labels, + **kwargs + ): + # Similar to `check_encoder_decoder_model_output_attentions`, but with `output_attentions` triggered from the + # config file. Contrarily to most models, changing the model's config won't work -- the defaults are loaded + # from the inner models' configurations. + + decoder_input_ids = decoder_input_ids[:, :-1] + decoder_attention_mask = decoder_attention_mask[:, :-1] + encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config) + enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model) + enc_dec_model.config.output_attentions = True # model config -> won't work + enc_dec_model.to(torch_device) + outputs_encoder_decoder = enc_dec_model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + ) + self.assertTrue( + all( + key not in outputs_encoder_decoder + for key in ["encoder_attentions", "decoder_attentions", "cross_attentions"] + ) + ) + + config.output_attentions = True # inner model config -> will work + decoder_config.output_attentions = True + encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config) + enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model) + enc_dec_model.to(torch_device) + outputs_encoder_decoder = enc_dec_model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + ) + self._check_output_with_attentions( + outputs_encoder_decoder, config, input_ids, decoder_config, decoder_input_ids + ) + def check_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs): encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config) enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model) @@ -543,6 +599,10 @@ class EncoderDecoderMixin: input_ids_dict = self.prepare_config_and_inputs() self.check_encoder_decoder_model_output_attentions(**input_ids_dict) + def test_encoder_decoder_model_output_attentions_from_config(self): + input_ids_dict = self.prepare_config_and_inputs() + self.check_encoder_decoder_model_output_attentions_from_config(**input_ids_dict) + def test_encoder_decoder_model_generate(self): input_ids_dict = self.prepare_config_and_inputs() self.check_encoder_decoder_model_generate(**input_ids_dict) diff --git a/tests/models/encoder_decoder/test_modeling_tf_encoder_decoder.py b/tests/models/encoder_decoder/test_modeling_tf_encoder_decoder.py index d2e9989457..d179d5f9d5 100644 --- a/tests/models/encoder_decoder/test_modeling_tf_encoder_decoder.py +++ b/tests/models/encoder_decoder/test_modeling_tf_encoder_decoder.py @@ -255,31 +255,9 @@ class TFEncoderDecoderMixin: outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,)) ) - def check_encoder_decoder_model_output_attentions( - self, - config, - input_ids, - attention_mask, - encoder_hidden_states, - decoder_config, - decoder_input_ids, - decoder_attention_mask, - **kwargs + def _check_output_with_attentions( + self, outputs_encoder_decoder, config, input_ids, decoder_config, decoder_input_ids ): - # make the decoder inputs a different shape from the encoder inputs to harden the test - decoder_input_ids = decoder_input_ids[:, :-1] - decoder_attention_mask = decoder_attention_mask[:, :-1] - encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config) - enc_dec_model = TFEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model) - outputs_encoder_decoder = enc_dec_model( - input_ids=input_ids, - decoder_input_ids=decoder_input_ids, - attention_mask=attention_mask, - decoder_attention_mask=decoder_attention_mask, - output_attentions=True, - kwargs=kwargs, - ) - encoder_attentions = outputs_encoder_decoder["encoder_attentions"] self.assertEqual(len(encoder_attentions), config.num_hidden_layers) @@ -311,6 +289,83 @@ class TFEncoderDecoderMixin: (decoder_config.num_attention_heads, cross_attention_input_seq_len, input_ids.shape[-1]), ) + def check_encoder_decoder_model_output_attentions( + self, + config, + input_ids, + attention_mask, + encoder_hidden_states, + decoder_config, + decoder_input_ids, + decoder_attention_mask, + **kwargs + ): + # make the decoder inputs a different shape from the encoder inputs to harden the test + decoder_input_ids = decoder_input_ids[:, :-1] + decoder_attention_mask = decoder_attention_mask[:, :-1] + encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config) + enc_dec_model = TFEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model) + outputs_encoder_decoder = enc_dec_model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + output_attentions=True, + kwargs=kwargs, + ) + self._check_output_with_attentions( + outputs_encoder_decoder, config, input_ids, decoder_config, decoder_input_ids + ) + + def check_encoder_decoder_model_output_attentions_from_config( + self, + config, + input_ids, + attention_mask, + encoder_hidden_states, + decoder_config, + decoder_input_ids, + decoder_attention_mask, + **kwargs + ): + # Similar to `check_encoder_decoder_model_output_attentions`, but with `output_attentions` triggered from the + # config file. Contrarily to most models, changing the model's config won't work -- the defaults are loaded + # from the inner models' configurations. + + decoder_input_ids = decoder_input_ids[:, :-1] + decoder_attention_mask = decoder_attention_mask[:, :-1] + encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config) + enc_dec_model = TFEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model) + enc_dec_model.config.output_attentions = True # model config -> won't work + outputs_encoder_decoder = enc_dec_model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + kwargs=kwargs, + ) + self.assertTrue( + all( + key not in outputs_encoder_decoder + for key in ["encoder_attentions", "decoder_attentions", "cross_attentions"] + ) + ) + + config.output_attentions = True # inner model config -> will work + decoder_config.output_attentions = True + encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config) + enc_dec_model = TFEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model) + outputs_encoder_decoder = enc_dec_model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + kwargs=kwargs, + ) + self._check_output_with_attentions( + outputs_encoder_decoder, config, input_ids, decoder_config, decoder_input_ids + ) + def check_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs): encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config) enc_dec_model = TFEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model) @@ -570,6 +625,10 @@ class TFEncoderDecoderMixin: input_ids_dict = self.prepare_config_and_inputs() self.check_encoder_decoder_model_output_attentions(**input_ids_dict) + def test_encoder_decoder_model_output_attentions_from_config(self): + input_ids_dict = self.prepare_config_and_inputs() + self.check_encoder_decoder_model_output_attentions_from_config(**input_ids_dict) + def test_encoder_decoder_model_generate(self): input_ids_dict = self.prepare_config_and_inputs() self.check_encoder_decoder_model_generate(**input_ids_dict)