add and improve comments
This commit is contained in:
@@ -491,7 +491,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
|
|||||||
return self.lm_head
|
return self.lm_head
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, input_ids, **kwargs):
|
def prepare_inputs_for_generation(self, input_ids, **kwargs):
|
||||||
# inputs_ids contain only last token if past is in kwargs and defined
|
# input_ids should contain only the last token if past is in kwargs and set
|
||||||
input_ids = input_ids[:, -1].unsqueeze(-1) if 'past' in kwargs and kwargs['past'] else input_ids
|
input_ids = input_ids[:, -1].unsqueeze(-1) if 'past' in kwargs and kwargs['past'] else input_ids
|
||||||
|
|
||||||
inputs = {"input_ids": input_ids}
|
inputs = {"input_ids": input_ids}
|
||||||
|
|||||||
@@ -560,7 +560,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
|||||||
return self.lm_head
|
return self.lm_head
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, input_ids, **kwargs):
|
def prepare_inputs_for_generation(self, input_ids, **kwargs):
|
||||||
# inputs_ids contain only last token if past is in kwargs and defined
|
# input_ids should contain only the last token if past is in kwargs and set
|
||||||
input_ids = input_ids[:, -1].unsqueeze(-1) if 'past' in kwargs and kwargs['past'] else input_ids
|
input_ids = input_ids[:, -1].unsqueeze(-1) if 'past' in kwargs and kwargs['past'] else input_ids
|
||||||
|
|
||||||
inputs = {"input_ids": input_ids}
|
inputs = {"input_ids": input_ids}
|
||||||
|
|||||||
@@ -732,6 +732,7 @@ class PreTrainedModel(nn.Module):
|
|||||||
outputs = self(**model_inputs)
|
outputs = self(**model_inputs)
|
||||||
next_token_logits = outputs[0][:, -1, :]
|
next_token_logits = outputs[0][:, -1, :]
|
||||||
|
|
||||||
|
# if model has past, then set the past parameter to speed up decoding
|
||||||
if self._has_past(outputs):
|
if self._has_past(outputs):
|
||||||
past = outputs[1]
|
past = outputs[1]
|
||||||
|
|
||||||
@@ -819,6 +820,7 @@ class PreTrainedModel(nn.Module):
|
|||||||
outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
|
outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
|
||||||
scores = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
|
scores = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
|
||||||
|
|
||||||
|
# if model has past, then set the past parameter to speed up decoding
|
||||||
if self._has_past(outputs):
|
if self._has_past(outputs):
|
||||||
past = outputs[1]
|
past = outputs[1]
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user