Patch-past-refactor (#21050)
* small patches, forgot a line * refactor PT * the actual fix
This commit is contained in:
@@ -3333,7 +3333,7 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
|
||||
if attention_mask is None:
|
||||
attention_mask = input_ids.new_ones(input_ids.shape)
|
||||
|
||||
if past:
|
||||
if past_key_values:
|
||||
input_ids = input_ids[:, -1:]
|
||||
# first step, decoder_cached_states are empty
|
||||
return {
|
||||
|
||||
Reference in New Issue
Block a user