Renamed `_has_past` to `_do_output_past` and changed it to check `self.config.output_past` instead of `self.output_past` (also skipping models that define `mem_len`)
This commit is contained in:
@@ -539,10 +539,10 @@ class PreTrainedModel(nn.Module):
|
|||||||
def prepare_inputs_for_generation(self, input_ids, **kwargs):
|
def prepare_inputs_for_generation(self, input_ids, **kwargs):
|
||||||
return {"input_ids": input_ids}
|
return {"input_ids": input_ids}
|
||||||
|
|
||||||
def _has_past(self, outputs):
|
def _do_output_past(self, outputs):
|
||||||
# TODO: might be better to write a self.has_past method for each individual class as is done for
|
# TODO: might be better to write a self.do_output_past method for each individual class as is done for
|
||||||
# prepare_inputs_for_generation
|
# prepare_inputs_for_generation
|
||||||
if hasattr(self, 'output_past') and self.output_past and len(outputs) > 1:
|
if hasattr(self.config, 'output_past') and self.config.output_past and len(outputs) > 1 and not hasattr(self, 'mem_len'):
|
||||||
return True
|
return True
|
||||||
# TODO: Add cases for (xlnet, transfo_xl) using mem_len
|
# TODO: Add cases for (xlnet, transfo_xl) using mem_len
|
||||||
return False
|
return False
|
||||||
@@ -732,7 +732,7 @@ class PreTrainedModel(nn.Module):
|
|||||||
next_token_logits = outputs[0][:, -1, :]
|
next_token_logits = outputs[0][:, -1, :]
|
||||||
|
|
||||||
# if model has past, then set the past variable to speed up decoding
|
# if model has past, then set the past variable to speed up decoding
|
||||||
if self._has_past(outputs):
|
if self._do_output_past(outputs):
|
||||||
past = outputs[1]
|
past = outputs[1]
|
||||||
|
|
||||||
# repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
|
# repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
|
||||||
@@ -819,7 +819,7 @@ class PreTrainedModel(nn.Module):
|
|||||||
scores = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
|
scores = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
|
||||||
|
|
||||||
# if model has past, then set the past variable to speed up decoding
|
# if model has past, then set the past variable to speed up decoding
|
||||||
if self._has_past(outputs):
|
if self._do_output_past(outputs):
|
||||||
past = outputs[1]
|
past = outputs[1]
|
||||||
|
|
||||||
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
|
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
|
||||||
|
|||||||
Reference in New Issue
Block a user