From 07708793f20ec3a949ccab32cc4fe0c7272dcc4c Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 6 Nov 2020 21:03:25 +0100 Subject: [PATCH] fix encoder outputs (#8368) --- src/transformers/generation_tf_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/transformers/generation_tf_utils.py b/src/transformers/generation_tf_utils.py index 71d8c6deb9..d61ee8f673 100644 --- a/src/transformers/generation_tf_utils.py +++ b/src/transformers/generation_tf_utils.py @@ -348,8 +348,7 @@ class TFGenerationMixin: shape=(-1,), ) # expand encoder_outputs - encoder_outputs = (tf.gather(encoder_outputs[0], expanded_batch_idxs, axis=0), *encoder_outputs[1:]) - + encoder_outputs = (tf.gather(encoder_outputs[0], expanded_batch_idxs, axis=0),) else: encoder_outputs = None cur_len = shape_list(input_ids)[-1]