From 365ccd0af20586aac9ca5312995584981db4dae4 Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Mon, 23 Dec 2019 23:55:05 +0100 Subject: [PATCH] make if statements cleaner for prepare_inputs_for_generation --- src/transformers/modeling_ctrl.py | 5 +++-- src/transformers/modeling_gpt2.py | 5 +++-- src/transformers/modeling_utils.py | 4 ++-- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/transformers/modeling_ctrl.py b/src/transformers/modeling_ctrl.py index a0b7cedb3b..91cf62b3b5 100644 --- a/src/transformers/modeling_ctrl.py +++ b/src/transformers/modeling_ctrl.py @@ -491,8 +491,9 @@ class CTRLLMHeadModel(CTRLPreTrainedModel): return self.lm_head def prepare_inputs_for_generation(self, input_ids, **kwargs): - # inputs_ids should only be composed of last token if past is in kwargs and defined - input_ids = input_ids[:, -1].unsqueeze(-1) if 'past' in kwargs and kwargs['past'] else input_ids + # only last token for input_ids if past is defined in kwargs + if 'past' in kwargs and kwargs['past']: + input_ids = input_ids[:, -1].unsqueeze(-1) inputs = {"input_ids": input_ids} inputs.update(kwargs) diff --git a/src/transformers/modeling_gpt2.py b/src/transformers/modeling_gpt2.py index 98581c670e..6e9b5066e9 100644 --- a/src/transformers/modeling_gpt2.py +++ b/src/transformers/modeling_gpt2.py @@ -560,8 +560,9 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): return self.lm_head def prepare_inputs_for_generation(self, input_ids, **kwargs): - # inputs_ids should only be composed of last token if past is in kwargs and defined - input_ids = input_ids[:, -1].unsqueeze(-1) if 'past' in kwargs and kwargs['past'] else input_ids + # only last token for input_ids if past is defined in kwargs + if 'past' in kwargs and kwargs['past']: + input_ids = input_ids[:, -1].unsqueeze(-1) inputs = {"input_ids": input_ids} inputs.update(kwargs) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index e5e4926af9..3248763bdb 100644 --- 
a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -540,8 +540,8 @@ class PreTrainedModel(nn.Module): return {"input_ids": input_ids} def _do_output_past(self, outputs): - # TODO: might be better to write a self.do_output_past method for each individual class as is done for - # prepare_inputs_for_generation + # TODO: might be better to write a self.do_output_past method for each + # individual class as is done for prepare_inputs_for_generation has_output_past = hasattr(self.config, 'output_past') and self.config.output_past has_multiple_outputs = len(outputs) > 1 has_mem_len = hasattr(self, 'mem_len')