Add past hidden key states for more efficient language generation, and add prepare_inputs_for_generation for the GPT-2 and CTRL models

This commit is contained in:
patrickvonplaten
2019-12-23 21:19:27 +01:00
parent aeef4823ab
commit d891fd0ae0
3 changed files with 38 additions and 7 deletions

View File

@@ -490,6 +490,14 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
def get_output_embeddings(self):
    """Return the object stored on ``self.lm_head``.

    NOTE(review): presumably this is the language-modeling head that maps
    hidden states to vocabulary logits -- confirm against the class
    constructor, which is not visible here.
    """
    output_embeddings = self.lm_head
    return output_embeddings
def prepare_inputs_for_generation(self, input_ids, **kwargs):
    """Assemble the keyword arguments for one generation step.

    When a truthy ``past`` entry is supplied in ``kwargs``, only the last
    position along dimension 1 of ``input_ids`` is kept (re-expanded with a
    trailing singleton dimension); otherwise ``input_ids`` is passed through
    untouched.

    Args:
        input_ids: Indexable object supporting ``[:, -1]`` and
            ``.unsqueeze(-1)``.
            # NOTE(review): presumably a (batch, seq_len) tensor of token
            # ids -- confirm against the caller.
        **kwargs: Extra model inputs, forwarded unchanged. An
            ``"input_ids"`` key in ``kwargs`` takes precedence over the
            positional argument.

    Returns:
        dict: ``{"input_ids": ...}`` merged with every entry of ``kwargs``.
    """
    # A truthy `past` means only the newest token has to be supplied.
    if kwargs.get("past"):
        input_ids = input_ids[:, -1].unsqueeze(-1)

    # Entries from kwargs come last so they override on key collision.
    return {"input_ids": input_ids, **kwargs}
def forward(
self,
input_ids=None,