Quickstart example showcasing past
@@ -188,3 +188,35 @@ assert predicted_text == 'Who was Jim Henson? Jim Henson was a man'
Examples for each model class of each model architecture (Bert, GPT, GPT-2, Transformer-XL, XLNet and XLM) can be found in the [documentation](#documentation).

#### Using the past
GPT-2 and several other models (GPT, XLNet, Transfo-XL, CTRL) make use of a `past` or `mems` attribute that holds the key/value pairs computed for earlier tokens, so they do not have to be recomputed at every step of sequential decoding. This is especially useful when generating sequences, because each new token attends to all previous positions and a large part of the attention computation can be reused.
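
To make the caching idea concrete, here is a dependency-free toy sketch, not the library's implementation. It assumes a single attention head, no layers, and a hypothetical `project` helper that stands in for the learned key/value/query projections by plain scaling. All function names are illustrative. It compares decoding that rebuilds keys and values for the whole prefix at every step with decoding that keeps them in a cache.

```python
import math


def attend(query, keys, values):
    """Scaled dot-product attention of one query over all given positions."""
    scale = math.sqrt(len(query))
    scores = [sum(q * k for q, k in zip(query, key)) / scale for key in keys]
    top = max(scores)  # shift by the max for numerical stability
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]


def project(x, scale):
    """Hypothetical stand-in for a learned linear projection (plain scaling)."""
    return [scale * xi for xi in x]


def decode_without_cache(inputs):
    """Recompute keys and values for the whole prefix at every step."""
    outputs = []
    for t in range(len(inputs)):
        prefix = inputs[: t + 1]
        keys = [project(x, 0.5) for x in prefix]
        values = [project(x, 2.0) for x in prefix]
        outputs.append(attend(project(inputs[t], 1.5), keys, values))
    return outputs


def decode_with_cache(inputs):
    """Compute each token's key and value once and keep them in a cache."""
    past_keys, past_values = [], []
    outputs = []
    for x in inputs:
        # Only the newest token is projected; earlier entries are reused.
        past_keys.append(project(x, 0.5))
        past_values.append(project(x, 2.0))
        outputs.append(attend(project(x, 1.5), past_keys, past_values))
    return outputs


inputs = [[1.0, 0.0], [0.5, -1.0], [2.0, 0.3], [-0.7, 0.9]]
full = decode_without_cache(inputs)
cached = decode_with_cache(inputs)
# Largest element-wise gap between the two decoding strategies.
max_diff = max(abs(a - b) for fa, ca in zip(full, cached) for a, b in zip(fa, ca))
print(max_diff)
```

Both functions produce the same attention outputs; the cached version simply avoids redoing the key/value projections for tokens it has already seen, which is the role `past` plays in the real models.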

Here is a fully working example that uses `past` with `GPT2LMHeadModel` and argmax decoding. Argmax decoding is used here only to keep the example short; it tends to introduce a lot of repetition:
```python
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

# Encode the prompt; `generated` accumulates every token id produced so far.
generated = tokenizer.encode("The Manhattan bridge")
context = torch.tensor([generated])
past = None  # nothing is cached before the first forward pass

for i in range(100):
    print(i)
    # `past` carries the cached key/value pairs from the previous steps.
    output, past = model(context, past=past)
    # Take the argmax over the logits of the last position only.
    token = torch.argmax(output[..., -1, :])

    generated += [token.tolist()]
    # From now on, feed only the newly generated token.
    context = token.unsqueeze(0)

sequence = tokenizer.decode(generated)

print(sequence)
```

After the first step, the model requires only a single token as input, because the key/value pairs of all previous tokens are already contained in `past`.
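
Since argmax decoding is noted above as repetitive, one common alternative is top-k sampling. The sketch below is a swapped-in technique, not something this commit adds: `sample_top_k` is a hypothetical helper written in plain Python, and the logits and `k` values are made up for illustration.

```python
import math
import random


def sample_top_k(logits, k, rng=random):
    """Sample a token id from the k highest-scoring entries of `logits`."""
    # Indices of the k largest logits, highest first.
    top_ids = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    # Softmax weights restricted to those k entries (shifted by the max for stability).
    peak = logits[top_ids[0]]
    weights = [math.exp(logits[i] - peak) for i in top_ids]
    return rng.choices(top_ids, weights=weights, k=1)[0]


logits = [0.1, 2.5, -1.0, 2.4, 0.7]  # illustrative values, not model output
rng = random.Random(0)
samples = [sample_top_k(logits, k=2, rng=rng) for _ in range(50)]
greedy = sample_top_k(logits, k=1)  # k=1 reduces to argmax
print(sorted(set(samples)), greedy)
```

In the generation loop, the argmax line could then be swapped for something like `token = torch.tensor(sample_top_k(output[0, -1, :].tolist(), k=40))`, where `k=40` is an arbitrary illustrative choice. The `past` handling stays the same, since caching does not depend on how the next token is chosen.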