TF generate refactor - past without encoder outputs (#15944)

* Remove packed past from generation_tf_utils

* update models with the new past format

* update template accordingly
This commit is contained in:
Joao Gante
2022-03-08 14:46:44 +00:00
committed by GitHub
parent 62d847602a
commit 70203b5937
30 changed files with 301 additions and 684 deletions

View File

@@ -1777,7 +1777,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
{% else %}
import random
from typing import Dict, Optional, Tuple, Union
from typing import Optional, Tuple, Union
import tensorflow as tf
@@ -2736,9 +2736,6 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
if inputs["output_hidden_states"]:
all_hidden_states += (hidden_states,)
if inputs["use_cache"]:
present_key_values = (inputs["encoder_hidden_states"], present_key_values)
if not inputs["return_dict"]:
return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
else:
@@ -3186,43 +3183,23 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past,
attention_mask,
past=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
use_cache=False,
use_cache=None,
encoder_outputs=None,
**kwargs
) -> Dict:
assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"
if len(past) == 1:
assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}"
encoder_outputs = TFBaseModelOutput(last_hidden_state=past[0])
past_key_values = None
else:
assert (
len(past) == 2
), "`past` has to be of length 2 with the encoder_outputs at the first position and past_key_values at the second position."
encoder_outputs, past_key_values = past
if isinstance(encoder_outputs, tuple):
assert isinstance(
encoder_outputs[0], tf.Tensor
), f"`encoder_outputs[0]` has to be of type `tf.Tensor`, but is {type(encoder_outputs[0])}"
encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs[0])
elif isinstance(encoder_outputs, tf.Tensor):
encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs)
assert (
past_key_values
), f"decoder cached states must be truthy. got {past_key_values} from the 2nd element of past"
):
# cut decoder_input_ids if past is used
if past is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
assert isinstance(
encoder_outputs, TFBaseModelOutput
), f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}."
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"input_ids": None, # needs to be passed to make Keras.layer.__call__ happy
"encoder_outputs": encoder_outputs,
"past_key_values": past_key_values,
"past_key_values": past,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
@@ -3233,17 +3210,10 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
@staticmethod
def _reorder_cache(past, beam_idx):
if len(past) == 1:
return past
past_key_values = past[1]
reordered_past = ()
for layer_past_key_values in past_key_values:
reordered_past += (
tuple(tf.gather(layer_past_key_value, beam_idx) for layer_past_key_value in layer_past_key_values[:2]) + layer_past_key_values[2:],
)
return (past[0], reordered_past)
for layer_past in past:
reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
return reordered_past
def hf_compute_loss(self, labels, logits):
"""CrossEntropyLoss that ignores pad tokens"""

View File

@@ -802,7 +802,6 @@ class TF{{cookiecutter.camelcase_modelname}}ModelTester:
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
output, past_key_values = outputs.to_tuple()
past_key_values = past_key_values[1]
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)