From d891fd0ae0d43ae71723c1df9622ffef33ce84aa Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Mon, 23 Dec 2019 21:19:27 +0100 Subject: [PATCH 01/12] add past hidden key states for more efficient language generation & add prepare_inputs for gpt2 and ctrl model --- src/transformers/modeling_ctrl.py | 8 ++++++++ src/transformers/modeling_gpt2.py | 8 ++++++++ src/transformers/modeling_utils.py | 29 ++++++++++++++++++++++------- 3 files changed, 38 insertions(+), 7 deletions(-) diff --git a/src/transformers/modeling_ctrl.py b/src/transformers/modeling_ctrl.py index b4906b7aa6..69e3a1a8da 100644 --- a/src/transformers/modeling_ctrl.py +++ b/src/transformers/modeling_ctrl.py @@ -490,6 +490,14 @@ class CTRLLMHeadModel(CTRLPreTrainedModel): def get_output_embeddings(self): return self.lm_head + def prepare_inputs_for_generation(self, input_ids, **kwargs): + # inputs_ids contain only last token if past is in kwargs and defined + input_ids = input_ids[:, -1].unsqueeze(-1) if 'past' in kwargs and kwargs['past'] else input_ids + + inputs = {"input_ids": input_ids} + inputs.update(kwargs) + return inputs + def forward( self, input_ids=None, diff --git a/src/transformers/modeling_gpt2.py b/src/transformers/modeling_gpt2.py index c8d5040f2f..d962259d28 100644 --- a/src/transformers/modeling_gpt2.py +++ b/src/transformers/modeling_gpt2.py @@ -559,6 +559,14 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): def get_output_embeddings(self): return self.lm_head + def prepare_inputs_for_generation(self, input_ids, **kwargs): + # inputs_ids contain only last token if past is in kwargs and defined + input_ids = input_ids[:, -1].unsqueeze(-1) if 'past' in kwargs and kwargs['past'] else input_ids + + inputs = {"input_ids": input_ids} + inputs.update(kwargs) + return inputs + def forward( self, input_ids=None, diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 8722d578fd..2c8d30e85f 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -18,6 +18,7 @@ import logging import os +import ipdb import torch from torch import nn @@ -539,6 +540,14 @@ class PreTrainedModel(nn.Module): def prepare_inputs_for_generation(self, input_ids, **kwargs): return {"input_ids": input_ids} + def _has_past(self, outputs): + # TODO: might be better to write a self.has_past method for each individual class as is done for + # prepare_inputs_for_generation + if hasattr(self, 'output_past') and self.output_past and len(outputs) > 1: + return True + # TODO: Add cases for (xlnet, transfo_xl) using mem_len + return False + @torch.no_grad() def generate( self, @@ -716,14 +725,16 @@ class PreTrainedModel(nn.Module): # current position / max lengths / length of generated sentences / unfinished sentences unfinished_sents = input_ids.new(batch_size).fill_(1) - # TODO: add cached compute states - pasts = None + past = None while cur_len < max_length: - model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts) + model_inputs = self.prepare_inputs_for_generation(input_ids, past=past) outputs = self(**model_inputs) next_token_logits = outputs[0][:, -1, :] + if self._has_past(outputs): + past = outputs[1] + # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858) if repetition_penalty != 1.0: for i in range(batch_size): @@ -782,6 +793,7 @@ class PreTrainedModel(nn.Module): ): """ Generate sequences for each example with beam search. """ + ipdb.set_trace() # Expand input to num beams input_ids = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len) input_ids = input_ids.contiguous().view(batch_size * num_beams, cur_len) # (batch_size * num_beams, cur_len) @@ -797,15 +809,18 @@ class PreTrainedModel(nn.Module): beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,) # cache compute states - pasts = None # self.prepare_pasts() + past = None # done sentences done = [False for _ in range(batch_size)] while cur_len < max_length: - model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts) - scores = self(**model_inputs)[0] # (batch_size * num_beams, cur_len, vocab_size) - scores = scores[:, -1, :] # (batch_size * num_beams, vocab_size) + model_inputs = self.prepare_inputs_for_generation(input_ids, past=past) + outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size) + scores = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size) + + if self._has_past(outputs): + past = outputs[1] # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858) if repetition_penalty != 1.0: From 267587c258f1972e3743695c61c7f369a47d9a90 Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Mon, 23 Dec 2019 22:04:42 +0100 Subject: [PATCH 02/12] add and improve comments --- src/transformers/modeling_ctrl.py | 2 +- src/transformers/modeling_gpt2.py | 2 +- src/transformers/modeling_utils.py | 2 ++ 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_ctrl.py b/src/transformers/modeling_ctrl.py index 69e3a1a8da..a0b7cedb3b 100644 --- a/src/transformers/modeling_ctrl.py +++ b/src/transformers/modeling_ctrl.py @@ -491,7 +491,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel): return self.lm_head def prepare_inputs_for_generation(self, input_ids, **kwargs): - # inputs_ids contain only last token if past is in kwargs and defined + # inputs_ids should only be composed of last token if past is in kwargs and defined input_ids = input_ids[:, -1].unsqueeze(-1) if 'past' in kwargs and kwargs['past'] else input_ids inputs = {"input_ids": input_ids} diff --git a/src/transformers/modeling_gpt2.py b/src/transformers/modeling_gpt2.py index d962259d28..98581c670e 100644 --- a/src/transformers/modeling_gpt2.py +++ b/src/transformers/modeling_gpt2.py @@ -560,7 +560,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): return self.lm_head def prepare_inputs_for_generation(self, input_ids, **kwargs): - # inputs_ids contain only last token if past is in kwargs and defined + # inputs_ids should only be composed of last token if past is in kwargs and defined input_ids = input_ids[:, -1].unsqueeze(-1) if 'past' in kwargs and kwargs['past'] else input_ids inputs = {"input_ids": input_ids} diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 2c8d30e85f..bfd2be8220 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -732,6 +732,7 @@ class PreTrainedModel(nn.Module): outputs = self(**model_inputs) next_token_logits = outputs[0][:, -1, :] + # if model has past, then set the past parameter to speed up decoding if self._has_past(outputs): past = outputs[1] @@ -819,6 +820,7 @@ class PreTrainedModel(nn.Module): outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size) scores = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size) + # if model has past, then set the past parameter to speed up decoding if self._has_past(outputs): past = outputs[1] From 7bb42712916d2f82ada62a4195fd70a7c28b94de Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Mon, 23 Dec 2019 22:10:35 +0100 Subject: [PATCH 03/12] remove ipdb debugging statements --- src/transformers/modeling_utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index bfd2be8220..54283617e0 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -18,7 +18,6 @@ import logging import os -import ipdb import torch from torch import nn @@ -794,7 +793,6 @@ class PreTrainedModel(nn.Module): ): """ Generate sequences for each example with beam search. """ - ipdb.set_trace() # Expand input to num beams input_ids = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len) input_ids = input_ids.contiguous().view(batch_size * num_beams, cur_len) # (batch_size * num_beams, cur_len) From eeaa402cd4b3cd84e525fbdd158d525650e3a47f Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Mon, 23 Dec 2019 22:15:06 +0100 Subject: [PATCH 04/12] rename comments --- src/transformers/modeling_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 54283617e0..9a07e56368 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -731,7 +731,7 @@ class PreTrainedModel(nn.Module): outputs = self(**model_inputs) next_token_logits = outputs[0][:, -1, :] - # if model has past, then set the past parameter to speed up decoding + # if model has past, then set the past variable to speed up decoding if self._has_past(outputs): past = outputs[1] @@ -818,7 +818,7 @@ class PreTrainedModel(nn.Module): outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size) scores = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size) - # if model has past, then set the past parameter to speed up decoding + # if model has past, then set the past variable to speed up decoding if self._has_past(outputs): past = outputs[1] From 7e0c5c731a0f0178d09c9827f486d68ac9ca9848 Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Mon, 23 Dec 2019 22:33:45 +0100 Subject: [PATCH 05/12] changed do_output_past function to check for self.config.output_past instead of self.output_past --- src/transformers/modeling_utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 9a07e56368..c727bb5f2e 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -539,10 +539,10 @@ class PreTrainedModel(nn.Module): def prepare_inputs_for_generation(self, input_ids, **kwargs): return {"input_ids": input_ids} - def _has_past(self, outputs): - # TODO: might be better to write a self.has_past method for each individual class as is done for + def _do_output_past(self, outputs): + # TODO: might be better to write a self.do_output_past method for each individual class as is done for # prepare_inputs_for_generation - if hasattr(self, 'output_past') and self.output_past and len(outputs) > 1: + if hasattr(self.config, 'output_past') and self.config.output_past and len(outputs) > 1 and not hasattr(self, 'mem_len'): return True # TODO: Add cases for (xlnet, transfo_xl) using mem_len return False @@ -732,7 +732,7 @@ class PreTrainedModel(nn.Module): next_token_logits = outputs[0][:, -1, :] # if model has past, then set the past variable to speed up decoding - if self._has_past(outputs): + if self._do_output_past(outputs): past = outputs[1] # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858) @@ -819,7 +819,7 @@ class PreTrainedModel(nn.Module): scores = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size) # if model has past, then set the past variable to speed up decoding - if self._has_past(outputs): + if self._do_output_past(outputs): past = outputs[1] # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858) From d039c679d21ff38182b6c0d18757682f5f50d2aa Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Mon, 23 Dec 2019 23:39:16 +0100 Subject: [PATCH 06/12] better naming for if statement --- src/transformers/modeling_utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index c727bb5f2e..e5e4926af9 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -542,7 +542,11 @@ class PreTrainedModel(nn.Module): def _do_output_past(self, outputs): # TODO: might be better to write a self.do_output_past method for each individual class as is done for # prepare_inputs_for_generation - if hasattr(self.config, 'output_past') and self.config.output_past and len(outputs) > 1 and not hasattr(self, 'mem_len'): + has_output_past = hasattr(self.config, 'output_past') and self.config.output_past + has_multiple_outputs = len(outputs) > 1 + has_mem_len = hasattr(self, 'mem_len') + + if has_output_past and has_multiple_outputs and not has_mem_len: return True # TODO: Add cases for (xlnet, transfo_xl) using mem_len return False From 365ccd0af20586aac9ca5312995584981db4dae4 Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Mon, 23 Dec 2019 23:55:05 +0100 Subject: [PATCH 07/12] make if statements cleaner for prepare_inputs_for_generation --- src/transformers/modeling_ctrl.py | 5 +++-- src/transformers/modeling_gpt2.py | 5 +++-- src/transformers/modeling_utils.py | 4 ++-- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/transformers/modeling_ctrl.py b/src/transformers/modeling_ctrl.py index a0b7cedb3b..91cf62b3b5 100644 --- a/src/transformers/modeling_ctrl.py +++ b/src/transformers/modeling_ctrl.py @@ -491,8 +491,9 @@ class CTRLLMHeadModel(CTRLPreTrainedModel): return self.lm_head def prepare_inputs_for_generation(self, input_ids, **kwargs): - # inputs_ids should only be composed of last token if past is in kwargs and defined - input_ids = input_ids[:, -1].unsqueeze(-1) if 'past' in kwargs and kwargs['past'] else input_ids + # only last token for inputs_ids if past is defined in kwargs + if 'past' in kwargs and kwargs['past']: + input_ids = input_ids[:, -1].unsqueeze(-1) inputs = {"input_ids": input_ids} inputs.update(kwargs) diff --git a/src/transformers/modeling_gpt2.py b/src/transformers/modeling_gpt2.py index 98581c670e..6e9b5066e9 100644 --- a/src/transformers/modeling_gpt2.py +++ b/src/transformers/modeling_gpt2.py @@ -560,8 +560,9 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): return self.lm_head def prepare_inputs_for_generation(self, input_ids, **kwargs): - # inputs_ids should only be composed of last token if past is in kwargs and defined - input_ids = input_ids[:, -1].unsqueeze(-1) if 'past' in kwargs and kwargs['past'] else input_ids + # only last token for inputs_ids if past is defined in kwargs + if 'past' in kwargs and kwargs['past']: + input_ids = input_ids[:, -1].unsqueeze(-1) inputs = {"input_ids": input_ids} inputs.update(kwargs) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index e5e4926af9..3248763bdb 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -540,8 +540,8 @@ class PreTrainedModel(nn.Module): return {"input_ids": input_ids} def _do_output_past(self, outputs): - # TODO: might be better to write a self.do_output_past method for each individual class as is done for - # prepare_inputs_for_generation + # TODO: might be better to write a self.do_output_past method for each + # individual class as is done for prepare_inputs_for_generation has_output_past = hasattr(self.config, 'output_past') and self.config.output_past has_multiple_outputs = len(outputs) > 1 has_mem_len = hasattr(self, 'mem_len') From 6bca56fdb0587a4291f8465a0a6e818f5541a5e3 Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Tue, 24 Dec 2019 01:02:58 +0100 Subject: [PATCH 08/12] check for self.config.mem_len instead of self.mem_len in _do_output_past --- src/transformers/modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 3248763bdb..f81bcbecae 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -544,7 +544,7 @@ class PreTrainedModel(nn.Module): # individual class as is done for prepare_inputs_for_generation has_output_past = hasattr(self.config, 'output_past') and self.config.output_past has_multiple_outputs = len(outputs) > 1 - has_mem_len = hasattr(self, 'mem_len') + has_mem_len = hasattr(self.config, 'mem_len') if has_output_past and has_multiple_outputs and not has_mem_len: return True From 90cda45e9e7da95f9084ceca6d631f64173b69c8 Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Wed, 25 Dec 2019 16:29:20 +0100 Subject: [PATCH 09/12] add past re-ordering for beam search --- src/transformers/modeling_utils.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index f81bcbecae..c0eaec9c2c 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -913,13 +913,18 @@ class PreTrainedModel(nn.Module): beam_words = input_ids.new([x[1] for x in next_batch_beam]) beam_idx = input_ids.new([x[2] for x in next_batch_beam]) - # re-order batch and internal states + # re-order batch input_ids = input_ids[beam_idx, :] input_ids = torch.cat([input_ids, beam_words.unsqueeze(1)], dim=-1) - # TODO: Activate cache - # for k in cache.keys(): - # if k != 'slen': - # cache[k] = (cache[k][0][beam_idx], cache[k][1][beam_idx]) + + # re-order internal states + if past: + reordered_past = [] + for layer_past in past: + # copy the relevant beam idx past to past + reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx] + reordered_past.append(torch.cat(reordered_layer_past, dim=1)) + past = tuple(reordered_past) # update current length cur_len = cur_len + 1 From 9398058e19d1ca89c881890b6dad72e384cc88c6 Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Wed, 25 Dec 2019 16:34:28 +0100 Subject: [PATCH 10/12] add easy tensor shape match test --- src/transformers/modeling_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index c0eaec9c2c..437ec8f6f0 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -923,7 +923,10 @@ class PreTrainedModel(nn.Module): for layer_past in past: # copy the relevant beam idx past to past reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx] - reordered_past.append(torch.cat(reordered_layer_past, dim=1)) + reordered_layer_past = torch.cat(reordered_layer_past, dim=1) + # check that shape matches + assert reordered_layer_past.shape == layer_past.shape + reordered_past.append(reordered_layer_past) past = tuple(reordered_past) # update current length From deff792bb6d0a099ba681d8513da0792f23162b4 Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Wed, 25 Dec 2019 18:50:39 +0100 Subject: [PATCH 11/12] add prepare inputs for transfo_xl and xlnet --- src/transformers/modeling_transfo_xl.py | 9 +++++++++ src/transformers/modeling_utils.py | 14 +++++++------- src/transformers/modeling_xlnet.py | 8 +++++++- 3 files changed, 23 insertions(+), 8 deletions(-) diff --git a/src/transformers/modeling_transfo_xl.py b/src/transformers/modeling_transfo_xl.py index 3589a3d87d..938ee86ec3 100644 --- a/src/transformers/modeling_transfo_xl.py +++ b/src/transformers/modeling_transfo_xl.py @@ -930,3 +930,12 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel): return self.out_layer else: return self.crit.out_layers[-1] + + def prepare_inputs_for_generation(self, input_ids, **model_kwargs): + inputs = {"input_ids": input_ids} + + # if past is defined in model kwargs then use it for faster decoding + if 'past' in model_kwargs and model_kwargs['past']: + inputs['mems'] = model_kwargs['past'] + + return inputs diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 437ec8f6f0..3e24b2b359 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -540,15 +540,14 @@ class PreTrainedModel(nn.Module): return {"input_ids": input_ids} def _do_output_past(self, outputs): - # TODO: might be better to write a self.do_output_past method for each - # individual class as is done for prepare_inputs_for_generation has_output_past = hasattr(self.config, 'output_past') and self.config.output_past - has_multiple_outputs = len(outputs) > 1 - has_mem_len = hasattr(self.config, 'mem_len') + has_mem_len = hasattr(self.config, 'mem_len') and self.config.mem_len - if has_output_past and has_multiple_outputs and not has_mem_len: + if has_output_past and not has_mem_len and len(outputs) > 1: return True - # TODO: Add cases for (xlnet, transfo_xl) using mem_len + elif has_mem_len and self.config.mem_len > 0 and len(outputs) > 1: + return True + return False @torch.no_grad() @@ -921,7 +920,8 @@ class PreTrainedModel(nn.Module): if past: reordered_past = [] for layer_past in past: - # copy the relevant beam idx past to past + # get the correct batch idx from layer past batch dim + # batch dim of `past` and `mems` is at 2nd position reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx] reordered_layer_past = torch.cat(reordered_layer_past, dim=1) # check that shape matches diff --git a/src/transformers/modeling_xlnet.py b/src/transformers/modeling_xlnet.py index be9c41b0e5..dc38821058 100644 --- a/src/transformers/modeling_xlnet.py +++ b/src/transformers/modeling_xlnet.py @@ -1028,7 +1028,13 @@ class XLNetLMHeadModel(XLNetPreTrainedModel): ) target_mapping[0, 0, -1] = 1.0 - return {"input_ids": input_ids, "perm_mask": perm_mask, "target_mapping": target_mapping} + inputs = {"input_ids": input_ids, "perm_mask": perm_mask, "target_mapping": target_mapping} + + # if past is defined in model kwargs then use it for faster decoding + if 'past' in model_kwargs and model_kwargs['past']: + inputs['mems'] = model_kwargs['past'] + + return inputs def forward( self, From fc84bd5254ed0f89f50c1491cc5b68135a8a5125 Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Wed, 25 Dec 2019 23:32:44 +0100 Subject: [PATCH 12/12] adapt style to predefined style layout --- src/transformers/modeling_ctrl.py | 2 +- src/transformers/modeling_gpt2.py | 2 +- src/transformers/modeling_transfo_xl.py | 4 ++-- src/transformers/modeling_utils.py | 4 ++-- src/transformers/modeling_xlnet.py | 4 ++-- 5 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/transformers/modeling_ctrl.py b/src/transformers/modeling_ctrl.py index 91cf62b3b5..d069209a48 100644 --- a/src/transformers/modeling_ctrl.py +++ b/src/transformers/modeling_ctrl.py @@ -492,7 +492,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel): def prepare_inputs_for_generation(self, input_ids, **kwargs): # only last token for inputs_ids if past is defined in kwargs - if 'past' in kwargs and kwargs['past']: + if "past" in kwargs and kwargs["past"]: input_ids = input_ids[:, -1].unsqueeze(-1) inputs = {"input_ids": input_ids} diff --git a/src/transformers/modeling_gpt2.py b/src/transformers/modeling_gpt2.py index 6e9b5066e9..7f8d1454de 100644 --- a/src/transformers/modeling_gpt2.py +++ b/src/transformers/modeling_gpt2.py @@ -561,7 +561,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): def prepare_inputs_for_generation(self, input_ids, **kwargs): # only last token for inputs_ids if past is defined in kwargs - if 'past' in kwargs and kwargs['past']: + if "past" in kwargs and kwargs["past"]: input_ids = input_ids[:, -1].unsqueeze(-1) inputs = {"input_ids": input_ids} diff --git a/src/transformers/modeling_transfo_xl.py b/src/transformers/modeling_transfo_xl.py index 938ee86ec3..394e656774 100644 --- a/src/transformers/modeling_transfo_xl.py +++ b/src/transformers/modeling_transfo_xl.py @@ -935,7 +935,7 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel): inputs = {"input_ids": input_ids} # if past is defined in model kwargs then use it for faster decoding - if 'past' in model_kwargs and model_kwargs['past']: - inputs['mems'] = model_kwargs['past'] + if "past" in model_kwargs and model_kwargs["past"]: + inputs["mems"] = model_kwargs["past"] return inputs diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 3e24b2b359..786f03b9fe 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -540,8 +540,8 @@ class PreTrainedModel(nn.Module): return {"input_ids": input_ids} def _do_output_past(self, outputs): - has_output_past = hasattr(self.config, 'output_past') and self.config.output_past - has_mem_len = hasattr(self.config, 'mem_len') and self.config.mem_len + has_output_past = hasattr(self.config, "output_past") and self.config.output_past + has_mem_len = hasattr(self.config, "mem_len") and self.config.mem_len if has_output_past and not has_mem_len and len(outputs) > 1: return True diff --git a/src/transformers/modeling_xlnet.py b/src/transformers/modeling_xlnet.py index dc38821058..8b00fa7e37 100644 --- a/src/transformers/modeling_xlnet.py +++ b/src/transformers/modeling_xlnet.py @@ -1031,8 +1031,8 @@ class XLNetLMHeadModel(XLNetPreTrainedModel): inputs = {"input_ids": input_ids, "perm_mask": perm_mask, "target_mapping": target_mapping} # if past is defined in model kwargs then use it for faster decoding - if 'past' in model_kwargs and model_kwargs['past']: - inputs['mems'] = model_kwargs['past'] + if "past" in model_kwargs and model_kwargs["past"]: + inputs["mems"] = model_kwargs["past"] return inputs