add prepare inputs for transfo_xl and xlnet

This commit is contained in:
patrickvonplaten
2019-12-25 18:50:39 +01:00
parent 9398058e19
commit deff792bb6
3 changed files with 23 additions and 8 deletions

View File

@@ -930,3 +930,12 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
return self.out_layer
else:
return self.crit.out_layers[-1]
def prepare_inputs_for_generation(self, input_ids, **model_kwargs):
inputs = {"input_ids": input_ids}
# if past is defined in model kwargs then use it for faster decoding
if 'past' in model_kwargs and model_kwargs['past']:
inputs['mems'] = model_kwargs['past']
return inputs