fix encoder outputs (#8368)

This commit is contained in:
Patrick von Platen
2020-11-06 21:03:25 +01:00
committed by GitHub
parent bc0d26d1de
commit 07708793f2

View File

@@ -348,8 +348,7 @@ class TFGenerationMixin:
shape=(-1,), shape=(-1,),
) )
# expand encoder_outputs # expand encoder_outputs
encoder_outputs = (tf.gather(encoder_outputs[0], expanded_batch_idxs, axis=0), *encoder_outputs[1:]) encoder_outputs = (tf.gather(encoder_outputs[0], expanded_batch_idxs, axis=0),)
else: else:
encoder_outputs = None encoder_outputs = None
cur_len = shape_list(input_ids)[-1] cur_len = shape_list(input_ids)[-1]