From 267587c258f1972e3743695c61c7f369a47d9a90 Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Mon, 23 Dec 2019 22:04:42 +0100 Subject: [PATCH] add and improve comments --- src/transformers/modeling_ctrl.py | 2 +- src/transformers/modeling_gpt2.py | 2 +- src/transformers/modeling_utils.py | 2 ++ 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_ctrl.py b/src/transformers/modeling_ctrl.py index 69e3a1a8da..a0b7cedb3b 100644 --- a/src/transformers/modeling_ctrl.py +++ b/src/transformers/modeling_ctrl.py @@ -491,7 +491,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel): return self.lm_head def prepare_inputs_for_generation(self, input_ids, **kwargs): - # inputs_ids contain only last token if past is in kwargs and defined + # inputs_ids should only be composed of last token if past is in kwargs and defined input_ids = input_ids[:, -1].unsqueeze(-1) if 'past' in kwargs and kwargs['past'] else input_ids inputs = {"input_ids": input_ids} diff --git a/src/transformers/modeling_gpt2.py b/src/transformers/modeling_gpt2.py index d962259d28..98581c670e 100644 --- a/src/transformers/modeling_gpt2.py +++ b/src/transformers/modeling_gpt2.py @@ -560,7 +560,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): return self.lm_head def prepare_inputs_for_generation(self, input_ids, **kwargs): - # inputs_ids contain only last token if past is in kwargs and defined + # inputs_ids should only be composed of last token if past is in kwargs and defined input_ids = input_ids[:, -1].unsqueeze(-1) if 'past' in kwargs and kwargs['past'] else input_ids inputs = {"input_ids": input_ids} diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 2c8d30e85f..bfd2be8220 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -732,6 +732,7 @@ class PreTrainedModel(nn.Module): outputs = self(**model_inputs) next_token_logits = outputs[0][:, -1, :] + # if model has past, then set the past parameter to speed up decoding if self._has_past(outputs): past = outputs[1] @@ -819,6 +820,7 @@ class PreTrainedModel(nn.Module): outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size) scores = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size) + # if model has past, then set the past parameter to speed up decoding if self._has_past(outputs): past = outputs[1]