@@ -268,9 +268,9 @@ class TFT5Attention(tf.keras.layers.Layer):
|
|||||||
), "past_key_value should have 2 past states: keys and values. Got {} past states".format(
|
), "past_key_value should have 2 past states: keys and values. Got {} past states".format(
|
||||||
len(past_key_value)
|
len(past_key_value)
|
||||||
)
|
)
|
||||||
real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
|
real_seq_length += shape_list(past_key_value[0])[2] if query_length is None else query_length
|
||||||
|
|
||||||
key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
|
key_length = real_seq_length if key_value_states is None else shape_list(key_value_states)[1]
|
||||||
|
|
||||||
def shape(hidden_states):
|
def shape(hidden_states):
|
||||||
""" projection """
|
""" projection """
|
||||||
@@ -1147,13 +1147,14 @@ class TFT5Model(TFT5PreTrainedModel):
|
|||||||
training=inputs["training"],
|
training=inputs["training"],
|
||||||
)
|
)
|
||||||
|
|
||||||
past = (inputs["encoder_outputs"], decoder_outputs[1]) if inputs["use_cache"] else None
|
|
||||||
|
|
||||||
if not inputs["return_dict"]:
|
if not inputs["return_dict"]:
|
||||||
|
past = (inputs["encoder_outputs"], decoder_outputs[1]) if inputs["use_cache"] else None
|
||||||
if past is not None:
|
if past is not None:
|
||||||
decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:]
|
decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:]
|
||||||
return decoder_outputs + inputs["encoder_outputs"]
|
return decoder_outputs + inputs["encoder_outputs"]
|
||||||
|
|
||||||
|
past = (inputs["encoder_outputs"].to_tuple(), decoder_outputs[1]) if inputs["use_cache"] else None
|
||||||
|
|
||||||
return TFSeq2SeqModelOutput(
|
return TFSeq2SeqModelOutput(
|
||||||
last_hidden_state=decoder_outputs.last_hidden_state,
|
last_hidden_state=decoder_outputs.last_hidden_state,
|
||||||
past_key_values=past,
|
past_key_values=past,
|
||||||
@@ -1332,8 +1333,8 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
|
|||||||
|
|
||||||
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
|
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
|
||||||
|
|
||||||
past = (inputs["encoder_outputs"], decoder_outputs[1]) if inputs["use_cache"] else None
|
|
||||||
if not inputs["return_dict"]:
|
if not inputs["return_dict"]:
|
||||||
|
past = (inputs["encoder_outputs"], decoder_outputs[1]) if inputs["use_cache"] else None
|
||||||
if past is not None:
|
if past is not None:
|
||||||
decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:]
|
decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:]
|
||||||
output = (logits,) + decoder_outputs[1:] + inputs["encoder_outputs"]
|
output = (logits,) + decoder_outputs[1:] + inputs["encoder_outputs"]
|
||||||
@@ -1358,6 +1359,8 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
|
|||||||
attentions=attentions,
|
attentions=attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
past = (inputs["encoder_outputs"].to_tuple(), decoder_outputs[1]) if inputs["use_cache"] else None
|
||||||
|
|
||||||
return TFSeq2SeqLMOutput(
|
return TFSeq2SeqLMOutput(
|
||||||
loss=loss,
|
loss=loss,
|
||||||
logits=logits,
|
logits=logits,
|
||||||
|
|||||||
Reference in New Issue
Block a user