Replace past with past_key_values (#20944)
* start cleanup * more updates * more models are affected * more updates * update generation utils * style * revert change that removed reorder cachce * update generation utils * style * style * remove reorder cache
This commit is contained in:
@@ -1121,15 +1121,15 @@ class TF{{cookiecutter.camelcase_modelname}}ForCausalLM(TF{{cookiecutter.camelca
|
||||
def get_lm_head(self) -> tf.keras.layers.Layer:
|
||||
return self.mlm.predictions
|
||||
|
||||
def prepare_inputs_for_generation(self, inputs, past=None, attention_mask=None, **model_kwargs):
|
||||
def prepare_inputs_for_generation(self, inputs, past_key_values=None, attention_mask=None, **model_kwargs):
|
||||
# cut decoder_input_ids if past is used
|
||||
if past:
|
||||
if past_key_values:
|
||||
inputs = tf.expand_dims(inputs[:, -1], -1)
|
||||
|
||||
return {
|
||||
"input_ids": inputs,
|
||||
"attention_mask": attention_mask,
|
||||
"past_key_values": past,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": model_kwargs["use_cache"],
|
||||
}
|
||||
|
||||
@@ -3003,7 +3003,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
decoder_input_ids,
|
||||
past=None,
|
||||
past_key_values=None,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
@@ -3013,13 +3013,13 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
|
||||
**kwargs
|
||||
):
|
||||
# cut decoder_input_ids if past is used
|
||||
if past is not None:
|
||||
if past_key_values is not None:
|
||||
decoder_input_ids = decoder_input_ids[:, -1:]
|
||||
|
||||
return {
|
||||
"input_ids": None, # needs to be passed to make Keras.layer.__call__ happy
|
||||
"encoder_outputs": encoder_outputs,
|
||||
"past_key_values": past,
|
||||
"past_key_values": past_key_values,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"head_mask": head_mask,
|
||||
|
||||
@@ -1167,7 +1167,7 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
|
||||
cross_attentions=outputs.cross_attentions,
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
|
||||
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
|
||||
input_shape = input_ids.shape
|
||||
|
||||
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
||||
@@ -1175,10 +1175,10 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
|
||||
attention_mask = input_ids.new_ones(input_shape)
|
||||
|
||||
# cut decoder_input_ids if past is used
|
||||
if past is not None:
|
||||
if past_key_values is not None:
|
||||
input_ids = input_ids[:, -1:]
|
||||
|
||||
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
|
||||
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
|
||||
|
||||
def _reorder_cache(self, past, beam_idx):
|
||||
reordered_past = ()
|
||||
@@ -2879,7 +2879,7 @@ class {{cookiecutter.camelcase_modelname}}ForConditionalGeneration({{cookiecutte
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
decoder_input_ids,
|
||||
past=None,
|
||||
past_key_values=None,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
@@ -2889,13 +2889,13 @@ class {{cookiecutter.camelcase_modelname}}ForConditionalGeneration({{cookiecutte
|
||||
**kwargs
|
||||
):
|
||||
# cut decoder_input_ids if past is used
|
||||
if past is not None:
|
||||
if past_key_values is not None:
|
||||
decoder_input_ids = decoder_input_ids[:, -1:]
|
||||
|
||||
return {
|
||||
"input_ids": None, # encoder_outputs is defined. input_ids not needed
|
||||
"encoder_outputs": encoder_outputs,
|
||||
"past_key_values": past,
|
||||
"past_key_values": past_key_values,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"head_mask": head_mask,
|
||||
@@ -3328,7 +3328,7 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
|
||||
cross_attentions=outputs.cross_attentions,
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
|
||||
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs):
|
||||
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
||||
if attention_mask is None:
|
||||
attention_mask = input_ids.new_ones(input_ids.shape)
|
||||
@@ -3339,7 +3339,7 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
|
||||
return {
|
||||
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
|
||||
"attention_mask": attention_mask,
|
||||
"past_key_values": past,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": use_cache,
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user