Patch-past-refactor (#21050)

* small patches, forgot a line

* refactor PT

* the actual fix
This commit is contained in:
Arthur
2023-01-09 18:12:13 +01:00
committed by GitHub
parent 48d4e147d8
commit e3ecbaa4ab
4 changed files with 5 additions and 8 deletions

View File

@@ -3333,7 +3333,7 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
if attention_mask is None:
attention_mask = input_ids.new_ones(input_ids.shape)
if past:
if past_key_values:
input_ids = input_ids[:, -1:]
# first step, decoder_cached_states are empty
return {