From e18f786cd57c8cb84a09c6ce5e0d7de0dd8b106e Mon Sep 17 00:00:00 2001 From: Lysandre Date: Thu, 14 Nov 2019 10:06:00 -0500 Subject: [PATCH] Quickstart example showcasing past --- docs/source/quickstart.md | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/docs/source/quickstart.md b/docs/source/quickstart.md index ccba75e7c0..530aff8eb0 100644 --- a/docs/source/quickstart.md +++ b/docs/source/quickstart.md @@ -188,3 +188,35 @@ assert predicted_text == 'Who was Jim Henson? Jim Henson was a man' ``` Examples for each model class of each model architecture (Bert, GPT, GPT-2, Transformer-XL, XLNet and XLM) can be found in the [documentation](#documentation). + +#### Using the past + +GPT-2 as well as some other models (GPT, XLNet, Transfo-XL, CTRL) make use of a `past` or `mems` attribute which can be used to prevent re-computing the key/value pairs when using sequential decoding. It is useful when generating sequences as a big part of the attention mechanism benefits from previous computations. + +Here is a fully-working example using the `past` with `GPT2LMHeadModel` and argmax decoding (which should only be used as an example, as argmax decoding introduces a lot of repetition): + +```python +from transformers import GPT2LMHeadModel, GPT2Tokenizer +import torch + +tokenizer = GPT2Tokenizer.from_pretrained("gpt2") +model = GPT2LMHeadModel.from_pretrained('gpt2') + +generated = tokenizer.encode("The Manhattan bridge") +context = torch.tensor([generated]) +past = None + +for i in range(100): + print(i) + output, past = model(context, past=past) + token = torch.argmax(output[0, :]) + + generated += [token.tolist()] + context = token.unsqueeze(0) + +sequence = tokenizer.decode(generated) + +print(sequence) +``` + +The model only requires a single token as input as all the previous tokens' key/value pairs are contained in the `past`. \ No newline at end of file