From 64103fb6beac8cc865945d3956266fd80b44f18f Mon Sep 17 00:00:00 2001 From: Julien Plu Date: Mon, 28 Dec 2020 20:52:18 +0100 Subject: [PATCH] Fix TransfoXL (#9302) --- src/transformers/models/transfo_xl/modeling_tf_transfo_xl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/transfo_xl/modeling_tf_transfo_xl.py b/src/transformers/models/transfo_xl/modeling_tf_transfo_xl.py index 8bb445e5d0..555c5ea09a 100644 --- a/src/transformers/models/transfo_xl/modeling_tf_transfo_xl.py +++ b/src/transformers/models/transfo_xl/modeling_tf_transfo_xl.py @@ -501,8 +501,8 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer): # There are `mlen + qlen` steps that can be cached into mems new_mems = [] - end_idx = mlen + max(0, qlen) - beg_idx = max(0, end_idx - self.mem_len) + end_idx = mlen + tf.math.maximum(0, qlen) + beg_idx = tf.math.maximum(0, end_idx - tf.convert_to_tensor(self.mem_len)) for i in range(len(hids)): cat = tf.concat([mems[i], hids[i]], axis=0)