Fix TransfoXL (#9302)
This commit is contained in:
@@ -501,8 +501,8 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
 
 # There are `mlen + qlen` steps that can be cached into mems
 new_mems = []
-end_idx = mlen + max(0, qlen)
-beg_idx = max(0, end_idx - self.mem_len)
+end_idx = mlen + tf.math.maximum(0, qlen)
+beg_idx = tf.math.maximum(0, end_idx - tf.convert_to_tensor(self.mem_len))
 for i in range(len(hids)):
 
     cat = tf.concat([mems[i], hids[i]], axis=0)
Reference in New Issue
Block a user