From 3f77c26d74e1282955fefa8dfff2451e44f6d4a9 Mon Sep 17 00:00:00 2001 From: Julien Plu Date: Wed, 3 Feb 2021 12:26:32 +0100 Subject: [PATCH] Fix Longformer and LED (#9942) * Fix Longformer and LED * Add a test for graph execution with inputs_embeds * Apply style --- .../models/led/modeling_tf_led.py | 21 +++++++------- .../longformer/modeling_tf_longformer.py | 4 +-- tests/test_modeling_tf_common.py | 29 +++++++++++++++++++ 3 files changed, 42 insertions(+), 12 deletions(-) diff --git a/src/transformers/models/led/modeling_tf_led.py b/src/transformers/models/led/modeling_tf_led.py index 9896d0bc2e..ba068869eb 100644 --- a/src/transformers/models/led/modeling_tf_led.py +++ b/src/transformers/models/led/modeling_tf_led.py @@ -1665,7 +1665,6 @@ class TFLEDEncoder(tf.keras.layers.Layer): def compute_hidden_states(self, hidden_states, padding_len): return hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states - @tf.function def _pad_to_window_size( self, input_ids, @@ -1685,26 +1684,28 @@ class TFLEDEncoder(tf.keras.layers.Layer): batch_size, seq_len = input_shape[:2] padding_len = (attention_window - seq_len % attention_window) % attention_window - if padding_len > 0: + if tf.math.greater(padding_len, 0): logger.info( "Input ids are automatically padded from {} to {} to be a multiple of `config.attention_window`: {}".format( seq_len, seq_len + padding_len, attention_window ) ) - paddings = tf.convert_to_tensor([[0, 0], [0, padding_len]]) + paddings = tf.convert_to_tensor([[0, 0], [0, padding_len]]) - if input_ids is not None: - input_ids = tf.pad(input_ids, paddings, constant_values=pad_token_id) + if input_ids is not None: + input_ids = tf.pad(input_ids, paddings, constant_values=pad_token_id) - if inputs_embeds is not None: + if inputs_embeds is not None: + + def pad_embeddings(): input_ids_padding = tf.fill((batch_size, padding_len), pad_token_id) inputs_embeds_padding = self.embed_tokens(input_ids_padding) - inputs_embeds = tf.concat([inputs_embeds, inputs_embeds_padding], axis=-2) + return tf.concat([inputs_embeds, inputs_embeds_padding], axis=-2) - attention_mask = tf.pad( - attention_mask, paddings, constant_values=False - ) # no attention on the padding tokens + inputs_embeds = tf.cond(tf.math.greater(padding_len, 0), pad_embeddings, lambda: inputs_embeds) + + attention_mask = tf.pad(attention_mask, paddings, constant_values=False) # no attention on the padding tokens return ( padding_len, diff --git a/src/transformers/models/longformer/modeling_tf_longformer.py b/src/transformers/models/longformer/modeling_tf_longformer.py index 20480f083c..e3c747c939 100644 --- a/src/transformers/models/longformer/modeling_tf_longformer.py +++ b/src/transformers/models/longformer/modeling_tf_longformer.py @@ -1836,7 +1836,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer): batch_size, seq_len = input_shape[:2] padding_len = (attention_window - seq_len % attention_window) % attention_window - if padding_len > 0: + if tf.math.greater(padding_len, 0): logger.info( "Input ids are automatically padded from {} to {} to be a multiple of `config.attention_window`: {}".format( seq_len, seq_len + padding_len, attention_window @@ -1859,7 +1859,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer): inputs_embeds_padding = self.embeddings(input_ids_padding) return tf.concat([inputs_embeds, inputs_embeds_padding], axis=-2) - inputs_embeds = tf.cond(padding_len > 0, pad_embeddings, lambda: inputs_embeds) + inputs_embeds = tf.cond(tf.math.greater(padding_len, 0), pad_embeddings, lambda: inputs_embeds) attention_mask = tf.pad(attention_mask, paddings, constant_values=False) # no attention on the padding tokens token_type_ids = tf.pad(token_type_ids, paddings, constant_values=0) # pad with token_type_id = 0 diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 1bab898c8f..c685b7a56f 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -884,6 +884,35 @@ class TFModelTesterMixin: model(inputs) + def test_graph_mode_with_inputs_embeds(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + + inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class)) + if not self.is_encoder_decoder: + input_ids = inputs["input_ids"] + del inputs["input_ids"] + else: + encoder_input_ids = inputs["input_ids"] + decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids) + del inputs["input_ids"] + inputs.pop("decoder_input_ids", None) + + if not self.is_encoder_decoder: + inputs["inputs_embeds"] = model.get_input_embeddings()(input_ids) + else: + inputs["inputs_embeds"] = model.get_input_embeddings()(encoder_input_ids) + inputs["decoder_inputs_embeds"] = model.get_input_embeddings()(decoder_input_ids) + + @tf.function + def run_in_graph_mode(): + return model(inputs) + + outputs = run_in_graph_mode() + self.assertIsNotNone(outputs) + def test_numpy_arrays_inputs(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()