From d97d06d05f3349f81716268df244d45b037518ef Mon Sep 17 00:00:00 2001 From: Julien Plu Date: Mon, 28 Dec 2020 20:51:40 +0100 Subject: [PATCH] Fix TF T5 (#9301) * Fix T5 * Fix test * Fix test --- src/transformers/models/t5/modeling_tf_t5.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/t5/modeling_tf_t5.py b/src/transformers/models/t5/modeling_tf_t5.py index f10cbf5f5d..437517ee49 100644 --- a/src/transformers/models/t5/modeling_tf_t5.py +++ b/src/transformers/models/t5/modeling_tf_t5.py @@ -268,9 +268,9 @@ class TFT5Attention(tf.keras.layers.Layer): ), "past_key_value should have 2 past states: keys and values. Got {} past states".format( len(past_key_value) ) - real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length + real_seq_length += shape_list(past_key_value[0])[2] if query_length is None else query_length - key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] + key_length = real_seq_length if key_value_states is None else shape_list(key_value_states)[1] def shape(hidden_states): """ projection """ @@ -1147,13 +1147,14 @@ class TFT5Model(TFT5PreTrainedModel): training=inputs["training"], ) - past = (inputs["encoder_outputs"], decoder_outputs[1]) if inputs["use_cache"] else None - if not inputs["return_dict"]: + past = (inputs["encoder_outputs"], decoder_outputs[1]) if inputs["use_cache"] else None if past is not None: decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:] return decoder_outputs + inputs["encoder_outputs"] + past = (inputs["encoder_outputs"].to_tuple(), decoder_outputs[1]) if inputs["use_cache"] else None + return TFSeq2SeqModelOutput( last_hidden_state=decoder_outputs.last_hidden_state, past_key_values=past, @@ -1332,8 +1333,8 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits) - past = (inputs["encoder_outputs"], decoder_outputs[1]) if inputs["use_cache"] else None if not inputs["return_dict"]: + past = (inputs["encoder_outputs"], decoder_outputs[1]) if inputs["use_cache"] else None if past is not None: decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:] output = (logits,) + decoder_outputs[1:] + inputs["encoder_outputs"] @@ -1358,6 +1359,8 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling attentions=attentions, ) + past = (inputs["encoder_outputs"].to_tuple(), decoder_outputs[1]) if inputs["use_cache"] else None + return TFSeq2SeqLMOutput( loss=loss, logits=logits,