fix: clean up inadvertent change in tf_t5
This was the beginnings of an attempt to address the test failure on this layer, and instead I backed out of making this layer keras-serializable at all ... so it was a mistake to commit this.
This commit is contained in:
@@ -383,21 +383,14 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
||||
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
training=False,
|
||||
):
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
hidden_states = inputs[0]
|
||||
assert len(inputs) <= 1, "Too many inputs."
|
||||
elif isinstance(inputs, dict):
|
||||
hidden_states = inputs["hidden_states"]
|
||||
assert len(inputs) <= 1, "Too many inputs."
|
||||
else:
|
||||
hidden_states = inputs
|
||||
|
||||
batch_size, seq_length = shape_list(hidden_states)[:2]
|
||||
if attention_mask is None:
|
||||
attention_mask = tf.fill((batch_size, seq_length), 1)
|
||||
|
||||
Reference in New Issue
Block a user