This commit is contained in:
Julien Plu
2021-01-04 15:56:51 +01:00
committed by GitHub
parent c581d8af7a
commit 6c03d4ac70

View File

@@ -375,7 +375,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
all_hidden_states = () if inputs["output_hidden_states"] else None
all_attentions = () if inputs["output_attentions"] else None
for i, (h, layer_past) in enumerate(zip(self.h, inputs["past"])):
if output_hidden_states:
if inputs["output_hidden_states"]:
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
outputs = h(
hidden_states,
@@ -384,7 +384,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
inputs["attention_mask"],
inputs["head_mask"][i],
inputs["use_cache"],
output_attentions,
inputs["output_attentions"],
training=inputs["training"],
)
hidden_states, present = outputs[:2]
@@ -392,7 +392,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
if inputs["use_cache"]:
presents = presents + (present,)
if output_attentions:
if inputs["output_attentions"]:
all_attentions = all_attentions + (outputs[2],)
hidden_states = self.layernorm(hidden_states)