From 7e0c5c731a0f0178d09c9827f486d68ac9ca9848 Mon Sep 17 00:00:00 2001
From: patrickvonplaten
Date: Mon, 23 Dec 2019 22:33:45 +0100
Subject: [PATCH] changed do_output_past function to check for self.config.output_past instead of self.output_past

---
 src/transformers/modeling_utils.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 9a07e56368..c727bb5f2e 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -539,10 +539,10 @@ class PreTrainedModel(nn.Module):
     def prepare_inputs_for_generation(self, input_ids, **kwargs):
         return {"input_ids": input_ids}
 
-    def _has_past(self, outputs):
-        # TODO: might be better to write a self.has_past method for each individual class as is done for
+    def _do_output_past(self, outputs):
+        # TODO: might be better to write a self.do_output_past method for each individual class as is done for
         # prepare_inputs_for_generation
-        if hasattr(self, 'output_past') and self.output_past and len(outputs) > 1:
+        if hasattr(self.config, 'output_past') and self.config.output_past and len(outputs) > 1 and not hasattr(self, 'mem_len'):
             return True
         # TODO: Add cases for (xlnet, transfo_xl) using mem_len
         return False
@@ -732,7 +732,7 @@ class PreTrainedModel(nn.Module):
             next_token_logits = outputs[0][:, -1, :]
 
             # if model has past, then set the past variable to speed up decoding
-            if self._has_past(outputs):
+            if self._do_output_past(outputs):
                 past = outputs[1]
 
             # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
@@ -819,7 +819,7 @@ class PreTrainedModel(nn.Module):
             scores = outputs[0][:, -1, :]  # (batch_size * num_beams, vocab_size)
 
             # if model has past, then set the past variable to speed up decoding
-            if self._has_past(outputs):
+            if self._do_output_past(outputs):
                 past = outputs[1]
 
             # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)