add prepare inputs for transfo_xl and xlnet
This commit is contained in:
@@ -930,3 +930,12 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
|
||||
return self.out_layer
|
||||
else:
|
||||
return self.crit.out_layers[-1]
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, **model_kwargs):
|
||||
inputs = {"input_ids": input_ids}
|
||||
|
||||
# if past is defined in model kwargs then use it for faster decoding
|
||||
if 'past' in model_kwargs and model_kwargs['past']:
|
||||
inputs['mems'] = model_kwargs['past']
|
||||
|
||||
return inputs
|
||||
|
||||
Reference in New Issue
Block a user