From deff792bb6d0a099ba681d8513da0792f23162b4 Mon Sep 17 00:00:00 2001
From: patrickvonplaten
Date: Wed, 25 Dec 2019 18:50:39 +0100
Subject: [PATCH] add prepare inputs for transfo_xl and xlnet

---
NOTE(review): line breaks and indentation were reconstructed from a
whitespace-collapsed copy of this patch. Hunk line counts were re-checked
against the @@ headers and the diffstat; leading indentation of the Python
lines is inferred from syntax -- confirm against the target tree before
applying.

 src/transformers/modeling_transfo_xl.py |  9 +++++++++
 src/transformers/modeling_utils.py      | 14 +++++++-------
 src/transformers/modeling_xlnet.py      |  8 +++++++-
 3 files changed, 23 insertions(+), 8 deletions(-)

diff --git a/src/transformers/modeling_transfo_xl.py b/src/transformers/modeling_transfo_xl.py
index 3589a3d87d..938ee86ec3 100644
--- a/src/transformers/modeling_transfo_xl.py
+++ b/src/transformers/modeling_transfo_xl.py
@@ -930,3 +930,12 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
             return self.out_layer
         else:
             return self.crit.out_layers[-1]
+
+    def prepare_inputs_for_generation(self, input_ids, **model_kwargs):
+        inputs = {"input_ids": input_ids}
+
+        # if past is defined in model kwargs then use it for faster decoding
+        if 'past' in model_kwargs and model_kwargs['past']:
+            inputs['mems'] = model_kwargs['past']
+
+        return inputs
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 437ec8f6f0..3e24b2b359 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -540,15 +540,14 @@ class PreTrainedModel(nn.Module):
         return {"input_ids": input_ids}
 
     def _do_output_past(self, outputs):
-        # TODO: might be better to write a self.do_output_past method for each
-        # individual class as is done for prepare_inputs_for_generation
         has_output_past = hasattr(self.config, 'output_past') and self.config.output_past
-        has_multiple_outputs = len(outputs) > 1
-        has_mem_len = hasattr(self.config, 'mem_len')
+        has_mem_len = hasattr(self.config, 'mem_len') and self.config.mem_len
 
-        if has_output_past and has_multiple_outputs and not has_mem_len:
+        if has_output_past and not has_mem_len and len(outputs) > 1:
             return True
-        # TODO: Add cases for (xlnet, transfo_xl) using mem_len
+        elif has_mem_len and self.config.mem_len > 0 and len(outputs) > 1:
+            return True
+
         return False
 
     @torch.no_grad()
@@ -921,7 +920,8 @@ class PreTrainedModel(nn.Module):
             if past:
                 reordered_past = []
                 for layer_past in past:
-                    # copy the relevant beam idx past to past
+                    # get the correct batch idx from layer past batch dim
+                    # batch dim of `past` and `mems` is at 2nd position
                     reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
                     reordered_layer_past = torch.cat(reordered_layer_past, dim=1)
                     # check that shape matches
diff --git a/src/transformers/modeling_xlnet.py b/src/transformers/modeling_xlnet.py
index be9c41b0e5..dc38821058 100644
--- a/src/transformers/modeling_xlnet.py
+++ b/src/transformers/modeling_xlnet.py
@@ -1028,7 +1028,13 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
         )
         target_mapping[0, 0, -1] = 1.0
 
-        return {"input_ids": input_ids, "perm_mask": perm_mask, "target_mapping": target_mapping}
+        inputs = {"input_ids": input_ids, "perm_mask": perm_mask, "target_mapping": target_mapping}
+
+        # if past is defined in model kwargs then use it for faster decoding
+        if 'past' in model_kwargs and model_kwargs['past']:
+            inputs['mems'] = model_kwargs['past']
+
+        return inputs
 
     def forward(
         self,