Fix CTRL (#9291)
This commit is contained in:
@@ -375,7 +375,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
|||||||
all_hidden_states = () if inputs["output_hidden_states"] else None
|
all_hidden_states = () if inputs["output_hidden_states"] else None
|
||||||
all_attentions = () if inputs["output_attentions"] else None
|
all_attentions = () if inputs["output_attentions"] else None
|
||||||
for i, (h, layer_past) in enumerate(zip(self.h, inputs["past"])):
|
for i, (h, layer_past) in enumerate(zip(self.h, inputs["past"])):
|
||||||
if output_hidden_states:
|
if inputs["output_hidden_states"]:
|
||||||
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
|
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
|
||||||
outputs = h(
|
outputs = h(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
@@ -384,7 +384,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
|||||||
inputs["attention_mask"],
|
inputs["attention_mask"],
|
||||||
inputs["head_mask"][i],
|
inputs["head_mask"][i],
|
||||||
inputs["use_cache"],
|
inputs["use_cache"],
|
||||||
output_attentions,
|
inputs["output_attentions"],
|
||||||
training=inputs["training"],
|
training=inputs["training"],
|
||||||
)
|
)
|
||||||
hidden_states, present = outputs[:2]
|
hidden_states, present = outputs[:2]
|
||||||
@@ -392,7 +392,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
|||||||
if inputs["use_cache"]:
|
if inputs["use_cache"]:
|
||||||
presents = presents + (present,)
|
presents = presents + (present,)
|
||||||
|
|
||||||
if output_attentions:
|
if inputs["output_attentions"]:
|
||||||
all_attentions = all_attentions + (outputs[2],)
|
all_attentions = all_attentions + (outputs[2],)
|
||||||
|
|
||||||
hidden_states = self.layernorm(hidden_states)
|
hidden_states = self.layernorm(hidden_states)
|
||||||
|
|||||||
Reference in New Issue
Block a user