From 6fe1cc08743ee8f8b76c622e5a0742413048e857 Mon Sep 17 00:00:00 2001 From: Gunnlaugur Thor Briem Date: Wed, 4 Mar 2020 23:24:13 +0000 Subject: [PATCH] fix: clean up inadvertent change in tf_t5 This was the beginnings of an attempt to address the test failure on this layer, and instead I backed out of making this layer keras-serializable at all ... so it was a mistake to commit this. --- src/transformers/modeling_tf_t5.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/src/transformers/modeling_tf_t5.py b/src/transformers/modeling_tf_t5.py index 41520e044b..db62e784b1 100644 --- a/src/transformers/modeling_tf_t5.py +++ b/src/transformers/modeling_tf_t5.py @@ -383,21 +383,14 @@ class TFT5MainLayer(tf.keras.layers.Layer): def call( self, - inputs, + hidden_states, attention_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, head_mask=None, training=False, ): - if isinstance(inputs, (tuple, list)): - hidden_states = inputs[0] - assert len(inputs) <= 1, "Too many inputs." - elif isinstance(inputs, dict): - hidden_states = inputs["hidden_states"] - assert len(inputs) <= 1, "Too many inputs." - else: - hidden_states = inputs + batch_size, seq_length = shape_list(hidden_states)[:2] if attention_mask is None: attention_mask = tf.fill((batch_size, seq_length), 1)