Generate: Add text streamer decoding options (#22544)
This commit is contained in:
@@ -42,6 +42,10 @@ class TextStreamer(BaseStreamer):
|
||||
Parameters:
|
||||
tokenizer (`AutoTokenizer`):
|
||||
The tokenized used to decode the tokens.
|
||||
skip_prompt (`bool`, *optional*, defaults to `False`):
|
||||
Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
|
||||
decode_kwargs (`dict`, *optional*):
|
||||
Additional keyword arguments to pass to the tokenizer's `decode` method.
|
||||
|
||||
Examples:
|
||||
|
||||
@@ -59,10 +63,15 @@ class TextStreamer(BaseStreamer):
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer: "AutoTokenizer"):
|
||||
def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
|
||||
self.tokenizer = tokenizer
|
||||
self.skip_prompt = skip_prompt
|
||||
self.decode_kwargs = decode_kwargs
|
||||
|
||||
# variables used in the streaming process
|
||||
self.token_cache = []
|
||||
self.print_len = 0
|
||||
self.next_tokens_are_prompt = True
|
||||
|
||||
def put(self, value):
|
||||
"""
|
||||
@@ -73,11 +82,15 @@ class TextStreamer(BaseStreamer):
|
||||
elif len(value.shape) > 1:
|
||||
value = value[0]
|
||||
|
||||
if self.skip_prompt and self.next_tokens_are_prompt:
|
||||
self.next_tokens_are_prompt = False
|
||||
return
|
||||
|
||||
# Add the new token to the cache and decodes the entire thing.
|
||||
self.token_cache.extend(value.tolist())
|
||||
text = self.tokenizer.decode(self.token_cache)
|
||||
text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)
|
||||
|
||||
# After symbol for a new line, we flush the cache.
|
||||
# After the symbol for a new line, we flush the cache.
|
||||
if text.endswith("\n"):
|
||||
printable_text = text[self.print_len :]
|
||||
self.token_cache = []
|
||||
@@ -94,30 +107,34 @@ class TextStreamer(BaseStreamer):
|
||||
"""Flushes any remaining cache and prints a newline to stdout."""
|
||||
# Flush the cache, if it exists
|
||||
if len(self.token_cache) > 0:
|
||||
text = self.tokenizer.decode(self.token_cache)
|
||||
text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)
|
||||
printable_text = text[self.print_len :]
|
||||
self.token_cache = []
|
||||
self.print_len = 0
|
||||
else:
|
||||
printable_text = ""
|
||||
|
||||
# Print a newline (and the remaining text, if any)
|
||||
self.next_tokens_are_prompt = True
|
||||
self.on_finalized_text(printable_text, stream_end=True)
|
||||
|
||||
def on_finalized_text(self, token: str, stream_end: bool = False):
|
||||
"""Prints the new text to stdout."""
|
||||
print(token, flush=True, end="" if not stream_end else None)
|
||||
def on_finalized_text(self, text: str, stream_end: bool = False):
|
||||
"""Prints the new text to stdout. If the stream is ending, also prints a newline."""
|
||||
print(text, flush=True, end="" if not stream_end else None)
|
||||
|
||||
|
||||
class TextIteratorStreamer(BaseStreamer):
|
||||
class TextIteratorStreamer(TextStreamer):
|
||||
"""
|
||||
Streamer that stores print-ready text in a queue, to be used by a downstream application as an iterator. This is
|
||||
useful for applications that want to use the generated text in a non-blocking way (e.g. in an interactive Gradio
|
||||
demo).
|
||||
useful for applications that benefit from acessing the generated text in a non-blocking way (e.g. in an interactive
|
||||
Gradio demo).
|
||||
|
||||
Parameters:
|
||||
tokenizer (`AutoTokenizer`):
|
||||
The tokenized used to decode the tokens.
|
||||
skip_prompt (`bool`, *optional*, defaults to `False`):
|
||||
Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
|
||||
decode_kwargs (`dict`, *optional*):
|
||||
Additional keyword arguments to pass to the tokenizer's `decode` method.
|
||||
|
||||
Examples:
|
||||
|
||||
@@ -142,58 +159,23 @@ class TextIteratorStreamer(BaseStreamer):
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer: "AutoTokenizer"):
|
||||
self.tokenizer = tokenizer
|
||||
self.token_cache = []
|
||||
self.print_len = 0
|
||||
self.queue = Queue()
|
||||
def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
|
||||
super().__init__(tokenizer, skip_prompt, **decode_kwargs)
|
||||
self.text_queue = Queue()
|
||||
self.stop_signal = None
|
||||
|
||||
def on_finalized_text(self, text: str, stream_end: bool = False):
|
||||
"""Put the new text in the queue. If the stream is ending, also put a stop signal in the queue."""
|
||||
self.text_queue.put(text)
|
||||
if stream_end:
|
||||
self.text_queue.put(self.stop_signal)
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
value = self.queue.get()
|
||||
value = self.text_queue.get()
|
||||
if value == self.stop_signal:
|
||||
raise StopIteration()
|
||||
else:
|
||||
return value
|
||||
|
||||
def put(self, value):
|
||||
"""
|
||||
Recives tokens, decodes them, and pushes text to the queue as soon as it form entire words.
|
||||
"""
|
||||
if len(value.shape) > 1 and value.shape[0] > 1:
|
||||
raise ValueError("TextStreamer only supports batch size 1")
|
||||
elif len(value.shape) > 1:
|
||||
value = value[0]
|
||||
|
||||
# Add the new token to the cache and decodes the entire thing.
|
||||
self.token_cache.extend(value.tolist())
|
||||
text = self.tokenizer.decode(self.token_cache)
|
||||
|
||||
# After symbol for a new line, we flush the cache.
|
||||
if text.endswith("\n"):
|
||||
printable_text = text[self.print_len :]
|
||||
self.token_cache = []
|
||||
self.print_len = 0
|
||||
# Otherwise, prints until the last space char (simple heuristic to avoid printing incomplete words,
|
||||
# which may change with the subsequent token -- there are probably smarter ways to do this!)
|
||||
else:
|
||||
printable_text = text[self.print_len : text.rfind(" ") + 1]
|
||||
self.print_len += len(printable_text)
|
||||
self.queue.put(printable_text)
|
||||
|
||||
def end(self):
|
||||
"""Flushes any remaining cache and puts the stop signal in the queue."""
|
||||
# Flush the cache, if it exists
|
||||
if len(self.token_cache) > 0:
|
||||
text = self.tokenizer.decode(self.token_cache)
|
||||
printable_text = text[self.print_len :]
|
||||
self.token_cache = []
|
||||
self.print_len = 0
|
||||
else:
|
||||
printable_text = ""
|
||||
|
||||
self.queue.put(printable_text)
|
||||
self.queue.put(self.stop_signal) # Put the stop signal
|
||||
|
||||
@@ -23,6 +23,8 @@ from ..test_modeling_common import ids_tensor
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
|
||||
@@ -63,3 +65,40 @@ class StreamerTester(unittest.TestCase):
|
||||
streamer_text += new_text
|
||||
|
||||
self.assertEqual(streamer_text, greedy_text)
|
||||
|
||||
def test_text_streamer_skip_prompt(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
|
||||
model.config.eos_token_id = -1
|
||||
|
||||
input_ids = ids_tensor((1, 5), vocab_size=model.config.vocab_size).to(torch_device)
|
||||
greedy_ids = model.generate(input_ids, max_new_tokens=10, do_sample=False)
|
||||
new_greedy_ids = greedy_ids[:, input_ids.shape[1] :]
|
||||
new_greedy_text = tokenizer.decode(new_greedy_ids[0])
|
||||
|
||||
with CaptureStdout() as cs:
|
||||
streamer = TextStreamer(tokenizer, skip_prompt=True)
|
||||
model.generate(input_ids, max_new_tokens=10, do_sample=False, streamer=streamer)
|
||||
# The greedy text should be printed to stdout, except for the final "\n" in the streamer
|
||||
streamer_text = cs.out[:-1]
|
||||
|
||||
self.assertEqual(streamer_text, new_greedy_text)
|
||||
|
||||
def test_text_streamer_decode_kwargs(self):
|
||||
# Tests that we can pass `decode_kwargs` to the streamer to control how the tokens are decoded. Must be tested
|
||||
# with actual models -- the dummy models' tokenizers are not aligned with their models, and
|
||||
# `skip_special_tokens=True` has no effect on them
|
||||
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
|
||||
model = AutoModelForCausalLM.from_pretrained("distilgpt2").to(torch_device)
|
||||
model.config.eos_token_id = -1
|
||||
|
||||
input_ids = torch.ones((1, 5), device=torch_device).long() * model.config.bos_token_id
|
||||
with CaptureStdout() as cs:
|
||||
streamer = TextStreamer(tokenizer, skip_special_tokens=True)
|
||||
model.generate(input_ids, max_new_tokens=1, do_sample=False, streamer=streamer)
|
||||
|
||||
# The prompt contains a special token, so the streamer should not print it. As such, the output text, when
|
||||
# re-tokenized, must only contain one token
|
||||
streamer_text = cs.out[:-1] # Remove the final "\n"
|
||||
streamer_text_tokenized = tokenizer(streamer_text, return_tensors="pt")
|
||||
self.assertEqual(streamer_text_tokenized.input_ids.shape, (1, 1))
|
||||
|
||||
Reference in New Issue
Block a user