fix: clean up inadvertent change in tf_t5

This was the beginnings of an attempt to address the test failure on
this layer, and instead I backed out of making this layer
keras-serializable at all ... so it was a mistake to commit this.
This commit is contained in:
Gunnlaugur Thor Briem
2020-03-04 23:24:13 +00:00
parent 18f4b9274f
commit 6fe1cc0874

View File

@@ -383,21 +383,14 @@ class TFT5MainLayer(tf.keras.layers.Layer):
def call( def call(
self, self,
inputs, hidden_states,
attention_mask=None, attention_mask=None,
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
head_mask=None, head_mask=None,
training=False, training=False,
): ):
if isinstance(inputs, (tuple, list)):
hidden_states = inputs[0]
assert len(inputs) <= 1, "Too many inputs."
elif isinstance(inputs, dict):
hidden_states = inputs["hidden_states"]
assert len(inputs) <= 1, "Too many inputs."
else:
hidden_states = inputs
batch_size, seq_length = shape_list(hidden_states)[:2] batch_size, seq_length = shape_list(hidden_states)[:2]
if attention_mask is None: if attention_mask is None:
attention_mask = tf.fill((batch_size, seq_length), 1) attention_mask = tf.fill((batch_size, seq_length), 1)