Universal Assisted Generation: Assisted generation with any assistant model (by Intel Labs) (#33383)
* Update candidate_generator.py * Update utils.py * add lookbehind params to _get_candidate_generator * make fixup * add unit tests * fix failing tests * add docstrings * fix docstrings; remove non-optimized AnyTokenizer * added any tokenizer generation correctness test * make fixup * fix assertion syntax * PR review fixes * address additional PR comments * fix tests * remove stropping criteria arg * make fixup * add AssistantConfig * fix prev_tokens branching * pass tokenizers through `generate()`kwargs * fix lookbehind values; tokenizer params WIP * fixup * AssistantConfig * remove AssistantConfig; apply PR suggestions * restructure tests * fixup * fix assistant_tokenizer arg validation * fixup * fix tests in TestAssistedCandidateGeneratorDifferentTokenizers * fix class docstring * PR suggestions * doc * doc update and improvements to `_validate_assistant()` --------- Co-authored-by: mosheber <moshe.berchansky@intel.com>
This commit is contained in:
@@ -408,14 +408,24 @@ For the complete list of the available parameters, refer to the [API documentati
|
||||
### Speculative Decoding
|
||||
|
||||
Speculative decoding (also known as assisted decoding) is a modification of the decoding strategies above, that uses an
|
||||
assistant model (ideally a much smaller one) with the same tokenizer, to generate a few candidate tokens. The main
|
||||
model then validates the candidate tokens in a single forward pass, which speeds up the decoding process. If
|
||||
`do_sample=True`, then the token validation with resampling introduced in the
|
||||
[speculative decoding paper](https://arxiv.org/pdf/2211.17192.pdf) is used.
|
||||
assistant model (ideally a much smaller one), to generate a few candidate tokens. The main model then validates the candidate
|
||||
tokens in a single forward pass, which speeds up the decoding process. If `do_sample=True`, then the token validation with
|
||||
resampling introduced in the [speculative decoding paper](https://arxiv.org/pdf/2211.17192.pdf) is used.
|
||||
Assisted decoding assumes the main and assistant models have the same tokenizer, otherwise, see Universal Assisted Decoding below.
|
||||
|
||||
Currently, only greedy search and sampling are supported with assisted decoding, and assisted decoding doesn't support batched inputs.
|
||||
To learn more about assisted decoding, check [this blog post](https://huggingface.co/blog/assisted-generation).
|
||||
|
||||
#### Universal Assisted Decoding
|
||||
|
||||
Universal Assisted Decoding (UAD) adds support for main and assistant models with different tokenizers.
|
||||
To use it, simply pass the tokenizers using the `tokenizer` and `assistant_tokenizer` arguments (see below).
|
||||
Internally, the main model input tokens are re-encoded into assistant model tokens, then candidate tokens are generated in the assistant encoding, which are
|
||||
in turn re-encoded into main model candidate tokens. Validation then proceeds as explained above.
|
||||
The re-encoding steps involve decoding token ids into text and then encoding the text using a different tokenizer.
|
||||
Since re-encoding the tokens may result in tokenization discrepancies, UAD finds the longest common subsequence between the source and target encodings,
|
||||
to ensure the new tokens include the correct prompt suffix.
|
||||
|
||||
To enable assisted decoding, set the `assistant_model` argument with a model.
|
||||
|
||||
```python
|
||||
@@ -435,6 +445,26 @@ To enable assisted decoding, set the `assistant_model` argument with a model.
|
||||
['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
|
||||
```
|
||||
|
||||
If the main and assistant models have different tokenizers, use Universal Assisted Decoding.
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
>>> prompt = "Alice and Bob"
|
||||
>>> checkpoint = "google/gemma-2-9b"
|
||||
>>> assistant_checkpoint = "double7/vicuna-68m"
|
||||
|
||||
>>> assistant_tokenizer = AutoTokenizer.from_pretrained(assistant_checkpoint)
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> model = AutoModelForCausalLM.from_pretrained(checkpoint)
|
||||
>>> assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint)
|
||||
>>> outputs = model.generate(**inputs, assistant_model=assistant_model, tokenizer=tokenizer, assistant_tokenizer=assistant_tokenizer)
|
||||
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
|
||||
```
|
||||
|
||||
When using assisted decoding with sampling methods, you can use the `temperature` argument to control the randomness,
|
||||
just like in multinomial sampling. However, in assisted decoding, reducing the temperature may help improve the latency.
|
||||
|
||||
@@ -458,6 +488,7 @@ just like in multinomial sampling. However, in assisted decoding, reducing the t
|
||||
|
||||
Alternatively, you can also set the `prompt_lookup_num_tokens` to trigger n-gram based assisted decoding, as opposed
|
||||
to model based assisted decoding. You can read more about it [here](https://twitter.com/joao_gante/status/1747322413006643259).
|
||||
|
||||
### DoLa Decoding
|
||||
|
||||
**D**ecoding by C**o**ntrasting **La**yers (DoLa) is a contrastive decoding strategy to improve the factuality and reduce the
|
||||
|
||||
Reference in New Issue
Block a user