Pipeline: simple API for assisted generation (#34504)

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
Joao Gante
2025-01-08 17:08:02 +00:00
committed by GitHub
parent 3f483beab9
commit 76da6ca034
14 changed files with 172 additions and 18 deletions


@@ -441,6 +441,28 @@ To enable assisted decoding, set the `assistant_model` argument with a model.
['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
```
<Tip>
If you're using a `pipeline` object, all you need to do is pass the assistant checkpoint as the `assistant_model` argument:
```python
>>> from transformers import pipeline
>>> import torch
>>> pipe = pipeline(
...     "text-generation",
...     model="meta-llama/Llama-3.1-8B",
...     assistant_model="meta-llama/Llama-3.2-1B",  # This extra line is all that's needed; it also works with UAD (universal assisted decoding)
...     torch_dtype=torch.bfloat16,
... )
>>> pipe_output = pipe("Once upon a time, ", max_new_tokens=50, do_sample=False)
>>> pipe_output[0]["generated_text"]
'Once upon a time, 3D printing was a niche technology that was only'
```
</Tip>
When using assisted decoding with sampling methods, you can use the `temperature` argument to control the randomness,
just like in multinomial sampling. However, in assisted decoding, reducing the temperature may also help reduce latency.
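
To build some intuition for why temperature can affect latency, here is a toy sketch in pure Python. It is not the library's implementation. It assumes the standard speculative sampling rule, where a token drafted from the assistant distribution `q` is accepted with probability `min(1, p/q)` under the main model's distribution `p`, so the expected acceptance rate is the sum of `min(p, q)` over the vocabulary. The logits and the helper names (`softmax`, `expected_acceptance`) are hypothetical and chosen only for illustration.

```python
import math


def softmax(logits, temperature):
    """Turn logits into probabilities after dividing by the temperature."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def expected_acceptance(target_logits, assistant_logits, temperature):
    """Expected probability that one drafted token is accepted.

    Assumes the accept rule min(1, p/q) with the token drawn from q,
    which averages out to sum(min(p_i, q_i)).
    """
    p = softmax(target_logits, temperature)     # main model (hypothetical logits)
    q = softmax(assistant_logits, temperature)  # assistant model (hypothetical logits)
    return sum(min(pi, qi) for pi, qi in zip(p, q))


# Made-up next-token logits over a 4-token vocabulary. Both models
# agree on the most likely token but differ on the rest.
target_logits = [2.0, 1.0, 0.5, 0.0]
assistant_logits = [1.8, 0.9, 0.6, 0.1]

for temperature in (1.0, 0.1):
    rate = expected_acceptance(target_logits, assistant_logits, temperature)
    print(f"temperature={temperature}: expected acceptance={rate:.4f}")
```

In this toy setting, the acceptance rate at `temperature=0.1` is higher than at `temperature=1.0`, because both distributions concentrate on the same token. More accepted candidates mean fewer forward passes of the main model per generated token. The gain depends on the two models agreeing on the most likely token, which is why the effect is described as something that may help rather than a guarantee.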