add and improve comments

This commit is contained in:
patrickvonplaten
2019-12-23 22:04:42 +01:00
parent d891fd0ae0
commit 267587c258
3 changed files with 4 additions and 2 deletions

View File

@@ -732,6 +732,7 @@ class PreTrainedModel(nn.Module):
outputs = self(**model_inputs)
next_token_logits = outputs[0][:, -1, :]
# if model has past, then set the past parameter to speed up decoding
if self._has_past(outputs):
past = outputs[1]
@@ -819,6 +820,7 @@ class PreTrainedModel(nn.Module):
outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
scores = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
# if model has past, then set the past parameter to speed up decoding
if self._has_past(outputs):
past = outputs[1]