Fix TransfoXL (#9302)

This commit is contained in:
Julien Plu
2020-12-28 20:52:18 +01:00
committed by GitHub
parent d97d06d05f
commit 64103fb6be

View File

@@ -501,8 +501,8 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
 # There are `mlen + qlen` steps that can be cached into mems
 new_mems = []
-end_idx = mlen + max(0, qlen)
-beg_idx = max(0, end_idx - self.mem_len)
+end_idx = mlen + tf.math.maximum(0, qlen)
+beg_idx = tf.math.maximum(0, end_idx - tf.convert_to_tensor(self.mem_len))
 for i in range(len(hids)):
     cat = tf.concat([mems[i], hids[i]], axis=0)