make if statements cleaner for prepare_inputs_for_generation

This commit is contained in:
patrickvonplaten
2019-12-23 23:55:05 +01:00
parent d039c679d2
commit 365ccd0af2
3 changed files with 8 additions and 6 deletions

View File

@@ -491,8 +491,9 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
return self.lm_head
def prepare_inputs_for_generation(self, input_ids, **kwargs):
# inputs_ids should only be composed of last token if past is in kwargs and defined
input_ids = input_ids[:, -1].unsqueeze(-1) if 'past' in kwargs and kwargs['past'] else input_ids
# only last token for inputs_ids if past is defined in kwargs
if 'past' in kwargs and kwargs['past']:
input_ids = input_ids[:, -1].unsqueeze(-1)
inputs = {"input_ids": input_ids}
inputs.update(kwargs)