[CI ] Remove past in favor of pat_key_values (#21443)
* fix past renamed to past_key_value * update more `past`that were ski^êd * fixup * remove changes made to rag * refactor `_reorder_cache` to use `past_key_values` * fix git `prepare_inputs_for_generation` to pass tests when false is needed in use_cache
This commit is contained in:
@@ -1180,9 +1180,9 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
|
||||
|
||||
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
|
||||
|
||||
def _reorder_cache(self, past, beam_idx):
|
||||
def _reorder_cache(self, past_key_values, beam_idx):
|
||||
reordered_past = ()
|
||||
for layer_past in past:
|
||||
for layer_past in past_key_values:
|
||||
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],)
|
||||
return reordered_past
|
||||
|
||||
@@ -2905,9 +2905,9 @@ class {{cookiecutter.camelcase_modelname}}ForConditionalGeneration({{cookiecutte
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past, beam_idx):
|
||||
def _reorder_cache(past_key_values, beam_idx):
|
||||
reordered_past = ()
|
||||
for layer_past in past:
|
||||
for layer_past in past_key_values:
|
||||
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
||||
return reordered_past
|
||||
|
||||
@@ -3344,9 +3344,9 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past, beam_idx):
|
||||
def _reorder_cache(past_key_values, beam_idx):
|
||||
reordered_past = ()
|
||||
for layer_past in past:
|
||||
for layer_past in past_key_values:
|
||||
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
||||
return reordered_past
|
||||
{% endif -%}
|
||||
|
||||
Reference in New Issue
Block a user