Token healing (#30081)

* token healing impl + trie with extensions

* make fixup

* prefix-robust space tokenization

* examples readme and requirements

* make fixup

* allow input prompt and model

* redundant defaults

* Specialized Trie

* make fixup

* updated tests with new inherited Tree

* input ids to auto device_map

* rm unused import

* Update src/transformers/generation/utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* naming convention

* Revert "naming convention"

This reverts commit dd39d9c5b7a969e2d8a8d2a8e54f121b82dc44f0.

* naming convention

* last -hopefully- changes

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Ahmed Moubtahij
2024-06-03 04:53:15 -04:00
committed by GitHub
parent 5b5b48b11d
commit 39b2ff69d6
7 changed files with 324 additions and 5 deletions

View File

@@ -0,0 +1,40 @@
<!-- back to top link -->
<a name="readme-top"></a>
<!-- ABOUT THE PROJECT -->
## What is token healing?
Token healing rectifies the token boundary bias in greedy tokenization. It does this by trimming and regrowing the prompt to better align with the model's tokenizer, thus enhancing generation quality. The improvement is clearest with completion models.
Example: given a completion prompt with a partial url ending with `:`, the model might have seen the expected completion `://` as a _single_ token in training. However, the prompt's tail token `:` tells it that the next token is not `//`, and so it looks for wrong completions. Such errors compound in auto-regressive language models.
Debiasing token boundaries also addresses output sensitivity to prompts ending with whitespace.
A more thorough explanation can be found on [The Art of Prompt Design: Prompt Boundaries and Token Healing | by Scott Lundberg](https://towardsdatascience.com/the-art-of-prompt-design-prompt-boundaries-and-token-healing-3b2448b0be38).
## Usage
```py
prompt = 'The link is <a href="http:'
raw_output = generate(prompt, completion_model, tokenizer, token_healing=False)
# The link is <a href="http:&#47;&#47;www&#47;dailymail&#
# The model saw '://' as a single token in training. Seeing a prompt ending with `:` tells it that the
# next token is likely not `//`, because otherwise it would've seen `://`.
# Thus, it completes with a token other than `//`, in this case, `&`.
healed_output = generate(prompt, completion_model, tokenizer, token_healing=True)
# The link is <a href="http://www.365doki.com/post/3699
# You can also use token healing in isolation
# This can be useful if you have other work to do before the generation
# Or if you want to delegate generation to another process
input_ids = tokenizer(test_prompts, return_tensors='pt', padding=True).input_ids.cuda()
healed_ids = model.heal_tokens(input_ids)
healed_prompts = tokenizer.batch_decode(healed_ids, skip_special_tokens=True)
# outputs the healed prompts without further completion/generation
```
See `run_token_healing.py` for the full example.
<p align="right">(<a href="#readme-top">back to top</a>)</p>

View File

@@ -0,0 +1,62 @@
import argparse
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
def generate(inputs, model, tokenizer, token_healing):
input_ids = tokenizer(inputs, return_tensors="pt", padding=True, device_map="auto").input_ids
generation_config = GenerationConfig(
max_new_tokens=8,
token_healing=token_healing,
pad_token_id=model.config.pad_token_id,
repetition_penalty=1.1,
)
output = model.generate(inputs=input_ids, generation_config=generation_config)
return tokenizer.batch_decode(output, skip_special_tokens=True)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--prompt", type=str)
parser.add_argument("--model_name_or_path", type=str, default="TheBloke/deepseek-llm-7B-base-GPTQ")
args = parser.parse_args()
prompts = (
[args.prompt]
if args.prompt
else [
'An example ["like this"] and another example [',
'The link is <a href="http:',
'The link is <a href="http', # test aggressive healing http->https
"I read a book about ", # test trailing whitespace
"I read a book about", # test nothing to heal
]
)
model_name_or_path = args.model_name_or_path
completion_model = AutoModelForCausalLM.from_pretrained(
model_name_or_path,
device_map="auto",
use_cache=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
raw_output = generate(prompts, completion_model, tokenizer, token_healing=False)
healed_output = generate(prompts, completion_model, tokenizer, token_healing=True)
for p, a, b in zip(prompts, raw_output, healed_output):
print(f"\nPrompt: {p}\nWithout healing:\n{a}\nWith healing:\n{b}")
# You can also use token healing in isolation
# This can be useful if you have other work to do before the generation
# Or if you want to delegate generation to another process
input_ids = tokenizer(prompts, return_tensors="pt", padding=True).input_ids.cuda()
healed_ids = completion_model.heal_tokens(input_ids)
healed_prompts = tokenizer.batch_decode(healed_ids, skip_special_tokens=True)
print("\nhealed prompts:")
for p in healed_prompts:
print(p)
if __name__ == "__main__":
main()