Universal Assisted Generation: Assisted generation with any assistant model (by Intel Labs) (#33383)

* Update candidate_generator.py

* Update utils.py

* add lookbehind params to _get_candidate_generator

* make fixup

* add unit tests

* fix failing tests

* add docstrings

* fix docstrings; remove non-optimized AnyTokenizer

* added any tokenizer generation correctness test

* make fixup

* fix assertion syntax

* PR review fixes

* address additional PR comments

* fix tests

* remove stropping criteria arg

* make fixup

* add AssistantConfig

* fix prev_tokens branching

* pass tokenizers through `generate()`kwargs

* fix lookbehind values; tokenizer params WIP

* fixup

* AssistantConfig

* remove AssistantConfig; apply PR suggestions

* restructure tests

* fixup

* fix assistant_tokenizer arg validation

* fixup

* fix tests in TestAssistedCandidateGeneratorDifferentTokenizers

* fix class docstring

* PR suggestions

* doc

* doc update and improvements to `_validate_assistant()`

---------

Co-authored-by: mosheber <moshe.berchansky@intel.com>
This commit is contained in:
Daniel Korat
2024-10-10 15:41:53 +03:00
committed by GitHub
parent dda3f91d06
commit fb0c6b521d
4 changed files with 440 additions and 10 deletions

View File

@@ -408,14 +408,24 @@ For the complete list of the available parameters, refer to the [API documentati
### Speculative Decoding
Speculative decoding (also known as assisted decoding) is a modification of the decoding strategies above, that uses an
assistant model (ideally a much smaller one) with the same tokenizer, to generate a few candidate tokens. The main
model then validates the candidate tokens in a single forward pass, which speeds up the decoding process. If
`do_sample=True`, then the token validation with resampling introduced in the
[speculative decoding paper](https://arxiv.org/pdf/2211.17192.pdf) is used.
assistant model (ideally a much smaller one), to generate a few candidate tokens. The main model then validates the candidate
tokens in a single forward pass, which speeds up the decoding process. If `do_sample=True`, then the token validation with
resampling introduced in the [speculative decoding paper](https://arxiv.org/pdf/2211.17192.pdf) is used.
Assisted decoding assumes the main and assistant models have the same tokenizer, otherwise, see Universal Assisted Decoding below.
Currently, only greedy search and sampling are supported with assisted decoding, and assisted decoding doesn't support batched inputs.
To learn more about assisted decoding, check [this blog post](https://huggingface.co/blog/assisted-generation).
#### Universal Assisted Decoding
Universal Assisted Decoding (UAD) adds support for main and assistant models with different tokenizers.
To use it, simply pass the tokenizers using the `tokenizer` and `assistant_tokenizer` arguments (see below).
Internally, the main model input tokens are re-encoded into assistant model tokens, then candidate tokens are generated in the assistant encoding, which are
in turn re-encoded into main model candidate tokens. Validation then proceeds as explained above.
The re-encoding steps involve decoding token ids into text and then encoding the text using a different tokenizer.
Since re-encoding the tokens may result in tokenization discrepancies, UAD finds the longest common subsequence between the source and target encodings,
to ensure the new tokens include the correct prompt suffix.
To enable assisted decoding, set the `assistant_model` argument with a model.
```python
@@ -435,6 +445,26 @@ To enable assisted decoding, set the `assistant_model` argument with a model.
['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
```
If the main and assistant models have different tokenizers, use Universal Assisted Decoding.
```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> prompt = "Alice and Bob"
>>> checkpoint = "google/gemma-2-9b"
>>> assistant_checkpoint = "double7/vicuna-68m"
>>> assistant_tokenizer = AutoTokenizer.from_pretrained(assistant_checkpoint)
>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> model = AutoModelForCausalLM.from_pretrained(checkpoint)
>>> assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint)
>>> outputs = model.generate(**inputs, assistant_model=assistant_model, tokenizer=tokenizer, assistant_tokenizer=assistant_tokenizer)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
```
When using assisted decoding with sampling methods, you can use the `temperature` argument to control the randomness,
just like in multinomial sampling. However, in assisted decoding, reducing the temperature may help improve the latency.
@@ -458,6 +488,7 @@ just like in multinomial sampling. However, in assisted decoding, reducing the t
Alternatively, you can also set the `prompt_lookup_num_tokens` to trigger n-gram based assisted decoding, as opposed
to model based assisted decoding. You can read more about it [here](https://twitter.com/joao_gante/status/1747322413006643259).
### DoLa Decoding
**D**ecoding by C**o**ntrasting **La**yers (DoLa) is a contrastive decoding strategy to improve the factuality and reduce the