From fc84bd5254ed0f89f50c1491cc5b68135a8a5125 Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Wed, 25 Dec 2019 23:32:44 +0100 Subject: [PATCH] adapt style to predefined style layout --- src/transformers/modeling_ctrl.py | 2 +- src/transformers/modeling_gpt2.py | 2 +- src/transformers/modeling_transfo_xl.py | 4 ++-- src/transformers/modeling_utils.py | 4 ++-- src/transformers/modeling_xlnet.py | 4 ++-- 5 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/transformers/modeling_ctrl.py b/src/transformers/modeling_ctrl.py index 91cf62b3b5..d069209a48 100644 --- a/src/transformers/modeling_ctrl.py +++ b/src/transformers/modeling_ctrl.py @@ -492,7 +492,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel): def prepare_inputs_for_generation(self, input_ids, **kwargs): # only last token for inputs_ids if past is defined in kwargs - if 'past' in kwargs and kwargs['past']: + if "past" in kwargs and kwargs["past"]: input_ids = input_ids[:, -1].unsqueeze(-1) inputs = {"input_ids": input_ids} diff --git a/src/transformers/modeling_gpt2.py b/src/transformers/modeling_gpt2.py index 6e9b5066e9..7f8d1454de 100644 --- a/src/transformers/modeling_gpt2.py +++ b/src/transformers/modeling_gpt2.py @@ -561,7 +561,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): def prepare_inputs_for_generation(self, input_ids, **kwargs): # only last token for inputs_ids if past is defined in kwargs - if 'past' in kwargs and kwargs['past']: + if "past" in kwargs and kwargs["past"]: input_ids = input_ids[:, -1].unsqueeze(-1) inputs = {"input_ids": input_ids} diff --git a/src/transformers/modeling_transfo_xl.py b/src/transformers/modeling_transfo_xl.py index 938ee86ec3..394e656774 100644 --- a/src/transformers/modeling_transfo_xl.py +++ b/src/transformers/modeling_transfo_xl.py @@ -935,7 +935,7 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel): inputs = {"input_ids": input_ids} # if past is defined in model kwargs then use it for faster decoding - if 'past' in model_kwargs and model_kwargs['past']: - inputs['mems'] = model_kwargs['past'] + if "past" in model_kwargs and model_kwargs["past"]: + inputs["mems"] = model_kwargs["past"] return inputs diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 3e24b2b359..786f03b9fe 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -540,8 +540,8 @@ class PreTrainedModel(nn.Module): return {"input_ids": input_ids} def _do_output_past(self, outputs): - has_output_past = hasattr(self.config, 'output_past') and self.config.output_past - has_mem_len = hasattr(self.config, 'mem_len') and self.config.mem_len + has_output_past = hasattr(self.config, "output_past") and self.config.output_past + has_mem_len = hasattr(self.config, "mem_len") and self.config.mem_len if has_output_past and not has_mem_len and len(outputs) > 1: return True diff --git a/src/transformers/modeling_xlnet.py b/src/transformers/modeling_xlnet.py index dc38821058..8b00fa7e37 100644 --- a/src/transformers/modeling_xlnet.py +++ b/src/transformers/modeling_xlnet.py @@ -1031,8 +1031,8 @@ class XLNetLMHeadModel(XLNetPreTrainedModel): inputs = {"input_ids": input_ids, "perm_mask": perm_mask, "target_mapping": target_mapping} # if past is defined in model kwargs then use it for faster decoding - if 'past' in model_kwargs and model_kwargs['past']: - inputs['mems'] = model_kwargs['past'] + if "past" in model_kwargs and model_kwargs["past"]: + inputs["mems"] = model_kwargs["past"] return inputs