Replace past with past_key_values (#20944)

* start cleanup

* more updates

* more models are affected

* more updates

* update generation utils

* style

* revert change that removed reorder cachce

* update generation utils

* style

* style

* remove reorder cache
This commit is contained in:
Arthur
2023-01-08 10:21:40 +01:00
committed by GitHub
parent 7cb596fa22
commit f0577df6de
84 changed files with 479 additions and 424 deletions

View File

@@ -1121,15 +1121,15 @@ class TF{{cookiecutter.camelcase_modelname}}ForCausalLM(TF{{cookiecutter.camelca
def get_lm_head(self) -> tf.keras.layers.Layer:
return self.mlm.predictions
def prepare_inputs_for_generation(self, inputs, past=None, attention_mask=None, **model_kwargs):
def prepare_inputs_for_generation(self, inputs, past_key_values=None, attention_mask=None, **model_kwargs):
# cut decoder_input_ids if past is used
if past:
if past_key_values:
inputs = tf.expand_dims(inputs[:, -1], -1)
return {
"input_ids": inputs,
"attention_mask": attention_mask,
"past_key_values": past,
"past_key_values": past_key_values,
"use_cache": model_kwargs["use_cache"],
}
@@ -3003,7 +3003,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past=None,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
@@ -3013,13 +3013,13 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
**kwargs
):
# cut decoder_input_ids if past is used
if past is not None:
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
return {
"input_ids": None, # needs to be passed to make Keras.layer.__call__ happy
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,

View File

@@ -1167,7 +1167,7 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
cross_attentions=outputs.cross_attentions,
)
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
@@ -1175,10 +1175,10 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
attention_mask = input_ids.new_ones(input_shape)
# cut decoder_input_ids if past is used
if past is not None:
if past_key_values is not None:
input_ids = input_ids[:, -1:]
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
def _reorder_cache(self, past, beam_idx):
reordered_past = ()
@@ -2879,7 +2879,7 @@ class {{cookiecutter.camelcase_modelname}}ForConditionalGeneration({{cookiecutte
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past=None,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
@@ -2889,13 +2889,13 @@ class {{cookiecutter.camelcase_modelname}}ForConditionalGeneration({{cookiecutte
**kwargs
):
# cut decoder_input_ids if past is used
if past is not None:
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
@@ -3328,7 +3328,7 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
cross_attentions=outputs.cross_attentions,
)
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs):
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_ids.shape)
@@ -3339,7 +3339,7 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
"attention_mask": attention_mask,
"past_key_values": past,
"past_key_values": past_key_values,
"use_cache": use_cache,
}