TF generate refactor - past without encoder outputs (#15944)
* Remove packed past from generation_tf_utils * update models with the new past format * update template accordingly
This commit is contained in:
@@ -1777,7 +1777,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
|
||||
|
||||
{% else %}
|
||||
import random
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
@@ -2736,9 +2736,6 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
|
||||
if inputs["output_hidden_states"]:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
if inputs["use_cache"]:
|
||||
present_key_values = (inputs["encoder_hidden_states"], present_key_values)
|
||||
|
||||
if not inputs["return_dict"]:
|
||||
return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
|
||||
else:
|
||||
@@ -3186,43 +3183,23 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
decoder_input_ids,
|
||||
past,
|
||||
attention_mask,
|
||||
past=None,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
use_cache=False,
|
||||
use_cache=None,
|
||||
encoder_outputs=None,
|
||||
**kwargs
|
||||
) -> Dict:
|
||||
assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"
|
||||
if len(past) == 1:
|
||||
assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}"
|
||||
encoder_outputs = TFBaseModelOutput(last_hidden_state=past[0])
|
||||
past_key_values = None
|
||||
else:
|
||||
assert (
|
||||
len(past) == 2
|
||||
), "`past` has to be of length 2 with the encoder_outputs at the first position and past_key_values at the second position."
|
||||
encoder_outputs, past_key_values = past
|
||||
if isinstance(encoder_outputs, tuple):
|
||||
assert isinstance(
|
||||
encoder_outputs[0], tf.Tensor
|
||||
), f"`encoder_outputs[0]` has to be of type `tf.Tensor`, but is {type(encoder_outputs[0])}"
|
||||
encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs[0])
|
||||
elif isinstance(encoder_outputs, tf.Tensor):
|
||||
encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs)
|
||||
assert (
|
||||
past_key_values
|
||||
), f"decoder cached states must be truthy. got {past_key_values} from the 2nd element of past"
|
||||
):
|
||||
# cut decoder_input_ids if past is used
|
||||
if past is not None:
|
||||
decoder_input_ids = decoder_input_ids[:, -1:]
|
||||
|
||||
assert isinstance(
|
||||
encoder_outputs, TFBaseModelOutput
|
||||
), f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}."
|
||||
return {
|
||||
"input_ids": None, # encoder_outputs is defined. input_ids not needed
|
||||
"input_ids": None, # needs to be passed to make Keras.layer.__call__ happy
|
||||
"encoder_outputs": encoder_outputs,
|
||||
"past_key_values": past_key_values,
|
||||
"past_key_values": past,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"head_mask": head_mask,
|
||||
@@ -3233,17 +3210,10 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past, beam_idx):
|
||||
if len(past) == 1:
|
||||
return past
|
||||
|
||||
past_key_values = past[1]
|
||||
|
||||
reordered_past = ()
|
||||
for layer_past_key_values in past_key_values:
|
||||
reordered_past += (
|
||||
tuple(tf.gather(layer_past_key_value, beam_idx) for layer_past_key_value in layer_past_key_values[:2]) + layer_past_key_values[2:],
|
||||
)
|
||||
return (past[0], reordered_past)
|
||||
for layer_past in past:
|
||||
reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
|
||||
return reordered_past
|
||||
|
||||
def hf_compute_loss(self, labels, logits):
|
||||
"""CrossEntropyLoss that ignores pad tokens"""
|
||||
|
||||
@@ -802,7 +802,6 @@ class TF{{cookiecutter.camelcase_modelname}}ModelTester:
|
||||
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
|
||||
|
||||
output, past_key_values = outputs.to_tuple()
|
||||
past_key_values = past_key_values[1]
|
||||
|
||||
# create hypothetical next token and extent to next_input_ids
|
||||
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
|
||||
|
||||
Reference in New Issue
Block a user