fix: clean up inadvertent change in tf_t5
This was the beginnings of an attempt to address the test failure on this layer, and instead I backed out of making this layer keras-serializable at all ... so it was a mistake to commit this.
This commit is contained in:
@@ -383,21 +383,14 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
inputs,
|
hidden_states,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
training=False,
|
training=False,
|
||||||
):
|
):
|
||||||
if isinstance(inputs, (tuple, list)):
|
|
||||||
hidden_states = inputs[0]
|
|
||||||
assert len(inputs) <= 1, "Too many inputs."
|
|
||||||
elif isinstance(inputs, dict):
|
|
||||||
hidden_states = inputs["hidden_states"]
|
|
||||||
assert len(inputs) <= 1, "Too many inputs."
|
|
||||||
else:
|
|
||||||
hidden_states = inputs
|
|
||||||
batch_size, seq_length = shape_list(hidden_states)[:2]
|
batch_size, seq_length = shape_list(hidden_states)[:2]
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = tf.fill((batch_size, seq_length), 1)
|
attention_mask = tf.fill((batch_size, seq_length), 1)
|
||||||
|
|||||||
Reference in New Issue
Block a user