Correct quickstart example when using the past

This commit is contained in:
Lysandre
2020-02-10 11:25:56 -05:00
parent 63a5399bc4
commit fd639e5be3

View File

@@ -209,7 +209,7 @@ past = None
for i in range(100): for i in range(100):
print(i) print(i)
output, past = model(context, past=past) output, past = model(context, past=past)
token = torch.argmax(output[0, :]) token = torch.argmax(output[..., -1, :])
generated += [token.tolist()] generated += [token.tolist()]
context = token.unsqueeze(0) context = token.unsqueeze(0)