@@ -268,9 +268,9 @@ class TFT5Attention(tf.keras.layers.Layer):
|
||||
), "past_key_value should have 2 past states: keys and values. Got {} past states".format(
|
||||
len(past_key_value)
|
||||
)
|
||||
real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
|
||||
real_seq_length += shape_list(past_key_value[0])[2] if query_length is None else query_length
|
||||
|
||||
key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
|
||||
key_length = real_seq_length if key_value_states is None else shape_list(key_value_states)[1]
|
||||
|
||||
def shape(hidden_states):
|
||||
""" projection """
|
||||
@@ -1147,13 +1147,14 @@ class TFT5Model(TFT5PreTrainedModel):
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
past = (inputs["encoder_outputs"], decoder_outputs[1]) if inputs["use_cache"] else None
|
||||
|
||||
if not inputs["return_dict"]:
|
||||
past = (inputs["encoder_outputs"], decoder_outputs[1]) if inputs["use_cache"] else None
|
||||
if past is not None:
|
||||
decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:]
|
||||
return decoder_outputs + inputs["encoder_outputs"]
|
||||
|
||||
past = (inputs["encoder_outputs"].to_tuple(), decoder_outputs[1]) if inputs["use_cache"] else None
|
||||
|
||||
return TFSeq2SeqModelOutput(
|
||||
last_hidden_state=decoder_outputs.last_hidden_state,
|
||||
past_key_values=past,
|
||||
@@ -1332,8 +1333,8 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
|
||||
|
||||
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
|
||||
|
||||
past = (inputs["encoder_outputs"], decoder_outputs[1]) if inputs["use_cache"] else None
|
||||
if not inputs["return_dict"]:
|
||||
past = (inputs["encoder_outputs"], decoder_outputs[1]) if inputs["use_cache"] else None
|
||||
if past is not None:
|
||||
decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:]
|
||||
output = (logits,) + decoder_outputs[1:] + inputs["encoder_outputs"]
|
||||
@@ -1358,6 +1359,8 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
|
||||
attentions=attentions,
|
||||
)
|
||||
|
||||
past = (inputs["encoder_outputs"].to_tuple(), decoder_outputs[1]) if inputs["use_cache"] else None
|
||||
|
||||
return TFSeq2SeqLMOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
|
||||
Reference in New Issue
Block a user