Generate: Add text streamer decoding options (#22544)
This commit is contained in:
@@ -42,6 +42,10 @@ class TextStreamer(BaseStreamer):
|
|||||||
Parameters:
|
Parameters:
|
||||||
tokenizer (`AutoTokenizer`):
|
tokenizer (`AutoTokenizer`):
|
||||||
The tokenizer used to decode the tokens.
|
The tokenizer used to decode the tokens.
|
||||||
|
skip_prompt (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
|
||||||
|
decode_kwargs (`dict`, *optional*):
|
||||||
|
Additional keyword arguments to pass to the tokenizer's `decode` method.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
|
|
||||||
@@ -59,10 +63,15 @@ class TextStreamer(BaseStreamer):
|
|||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, tokenizer: "AutoTokenizer"):
|
def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
|
self.skip_prompt = skip_prompt
|
||||||
|
self.decode_kwargs = decode_kwargs
|
||||||
|
|
||||||
|
# variables used in the streaming process
|
||||||
self.token_cache = []
|
self.token_cache = []
|
||||||
self.print_len = 0
|
self.print_len = 0
|
||||||
|
self.next_tokens_are_prompt = True
|
||||||
|
|
||||||
def put(self, value):
|
def put(self, value):
|
||||||
"""
|
"""
|
||||||
@@ -73,11 +82,15 @@ class TextStreamer(BaseStreamer):
|
|||||||
elif len(value.shape) > 1:
|
elif len(value.shape) > 1:
|
||||||
value = value[0]
|
value = value[0]
|
||||||
|
|
||||||
|
if self.skip_prompt and self.next_tokens_are_prompt:
|
||||||
|
self.next_tokens_are_prompt = False
|
||||||
|
return
|
||||||
|
|
||||||
# Add the new token to the cache and decodes the entire thing.
|
# Add the new token to the cache and decodes the entire thing.
|
||||||
self.token_cache.extend(value.tolist())
|
self.token_cache.extend(value.tolist())
|
||||||
text = self.tokenizer.decode(self.token_cache)
|
text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)
|
||||||
|
|
||||||
# After symbol for a new line, we flush the cache.
|
# After the symbol for a new line, we flush the cache.
|
||||||
if text.endswith("\n"):
|
if text.endswith("\n"):
|
||||||
printable_text = text[self.print_len :]
|
printable_text = text[self.print_len :]
|
||||||
self.token_cache = []
|
self.token_cache = []
|
||||||
@@ -94,30 +107,34 @@ class TextStreamer(BaseStreamer):
|
|||||||
"""Flushes any remaining cache and prints a newline to stdout."""
|
"""Flushes any remaining cache and prints a newline to stdout."""
|
||||||
# Flush the cache, if it exists
|
# Flush the cache, if it exists
|
||||||
if len(self.token_cache) > 0:
|
if len(self.token_cache) > 0:
|
||||||
text = self.tokenizer.decode(self.token_cache)
|
text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)
|
||||||
printable_text = text[self.print_len :]
|
printable_text = text[self.print_len :]
|
||||||
self.token_cache = []
|
self.token_cache = []
|
||||||
self.print_len = 0
|
self.print_len = 0
|
||||||
else:
|
else:
|
||||||
printable_text = ""
|
printable_text = ""
|
||||||
|
|
||||||
# Print a newline (and the remaining text, if any)
|
self.next_tokens_are_prompt = True
|
||||||
self.on_finalized_text(printable_text, stream_end=True)
|
self.on_finalized_text(printable_text, stream_end=True)
|
||||||
|
|
||||||
def on_finalized_text(self, token: str, stream_end: bool = False):
|
def on_finalized_text(self, text: str, stream_end: bool = False):
|
||||||
"""Prints the new text to stdout."""
|
"""Prints the new text to stdout. If the stream is ending, also prints a newline."""
|
||||||
print(token, flush=True, end="" if not stream_end else None)
|
print(text, flush=True, end="" if not stream_end else None)
|
||||||
|
|
||||||
|
|
||||||
class TextIteratorStreamer(BaseStreamer):
|
class TextIteratorStreamer(TextStreamer):
|
||||||
"""
|
"""
|
||||||
Streamer that stores print-ready text in a queue, to be used by a downstream application as an iterator. This is
|
Streamer that stores print-ready text in a queue, to be used by a downstream application as an iterator. This is
|
||||||
useful for applications that want to use the generated text in a non-blocking way (e.g. in an interactive Gradio
|
useful for applications that benefit from accessing the generated text in a non-blocking way (e.g. in an interactive
|
||||||
demo).
|
Gradio demo).
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
tokenizer (`AutoTokenizer`):
|
tokenizer (`AutoTokenizer`):
|
||||||
The tokenizer used to decode the tokens.
|
The tokenizer used to decode the tokens.
|
||||||
|
skip_prompt (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
|
||||||
|
decode_kwargs (`dict`, *optional*):
|
||||||
|
Additional keyword arguments to pass to the tokenizer's `decode` method.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
|
|
||||||
@@ -142,58 +159,23 @@ class TextIteratorStreamer(BaseStreamer):
|
|||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, tokenizer: "AutoTokenizer"):
|
def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
|
||||||
self.tokenizer = tokenizer
|
super().__init__(tokenizer, skip_prompt, **decode_kwargs)
|
||||||
self.token_cache = []
|
self.text_queue = Queue()
|
||||||
self.print_len = 0
|
|
||||||
self.queue = Queue()
|
|
||||||
self.stop_signal = None
|
self.stop_signal = None
|
||||||
|
|
||||||
|
def on_finalized_text(self, text: str, stream_end: bool = False):
|
||||||
|
"""Put the new text in the queue. If the stream is ending, also put a stop signal in the queue."""
|
||||||
|
self.text_queue.put(text)
|
||||||
|
if stream_end:
|
||||||
|
self.text_queue.put(self.stop_signal)
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __next__(self):
|
def __next__(self):
|
||||||
value = self.queue.get()
|
value = self.text_queue.get()
|
||||||
if value == self.stop_signal:
|
if value == self.stop_signal:
|
||||||
raise StopIteration()
|
raise StopIteration()
|
||||||
else:
|
else:
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def put(self, value):
|
|
||||||
"""
|
|
||||||
Receives tokens, decodes them, and pushes text to the queue as soon as it forms entire words.
|
|
||||||
"""
|
|
||||||
if len(value.shape) > 1 and value.shape[0] > 1:
|
|
||||||
raise ValueError("TextStreamer only supports batch size 1")
|
|
||||||
elif len(value.shape) > 1:
|
|
||||||
value = value[0]
|
|
||||||
|
|
||||||
# Add the new token to the cache and decodes the entire thing.
|
|
||||||
self.token_cache.extend(value.tolist())
|
|
||||||
text = self.tokenizer.decode(self.token_cache)
|
|
||||||
|
|
||||||
# After symbol for a new line, we flush the cache.
|
|
||||||
if text.endswith("\n"):
|
|
||||||
printable_text = text[self.print_len :]
|
|
||||||
self.token_cache = []
|
|
||||||
self.print_len = 0
|
|
||||||
# Otherwise, prints until the last space char (simple heuristic to avoid printing incomplete words,
|
|
||||||
# which may change with the subsequent token -- there are probably smarter ways to do this!)
|
|
||||||
else:
|
|
||||||
printable_text = text[self.print_len : text.rfind(" ") + 1]
|
|
||||||
self.print_len += len(printable_text)
|
|
||||||
self.queue.put(printable_text)
|
|
||||||
|
|
||||||
def end(self):
|
|
||||||
"""Flushes any remaining cache and puts the stop signal in the queue."""
|
|
||||||
# Flush the cache, if it exists
|
|
||||||
if len(self.token_cache) > 0:
|
|
||||||
text = self.tokenizer.decode(self.token_cache)
|
|
||||||
printable_text = text[self.print_len :]
|
|
||||||
self.token_cache = []
|
|
||||||
self.print_len = 0
|
|
||||||
else:
|
|
||||||
printable_text = ""
|
|
||||||
|
|
||||||
self.queue.put(printable_text)
|
|
||||||
self.queue.put(self.stop_signal) # Put the stop signal
|
|
||||||
|
|||||||
@@ -23,6 +23,8 @@ from ..test_modeling_common import ids_tensor
|
|||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
from transformers import AutoModelForCausalLM
|
from transformers import AutoModelForCausalLM
|
||||||
|
|
||||||
|
|
||||||
@@ -63,3 +65,40 @@ class StreamerTester(unittest.TestCase):
|
|||||||
streamer_text += new_text
|
streamer_text += new_text
|
||||||
|
|
||||||
self.assertEqual(streamer_text, greedy_text)
|
self.assertEqual(streamer_text, greedy_text)
|
||||||
|
|
||||||
|
def test_text_streamer_skip_prompt(self):
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||||
|
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
|
||||||
|
model.config.eos_token_id = -1
|
||||||
|
|
||||||
|
input_ids = ids_tensor((1, 5), vocab_size=model.config.vocab_size).to(torch_device)
|
||||||
|
greedy_ids = model.generate(input_ids, max_new_tokens=10, do_sample=False)
|
||||||
|
new_greedy_ids = greedy_ids[:, input_ids.shape[1] :]
|
||||||
|
new_greedy_text = tokenizer.decode(new_greedy_ids[0])
|
||||||
|
|
||||||
|
with CaptureStdout() as cs:
|
||||||
|
streamer = TextStreamer(tokenizer, skip_prompt=True)
|
||||||
|
model.generate(input_ids, max_new_tokens=10, do_sample=False, streamer=streamer)
|
||||||
|
# The greedy text should be printed to stdout, except for the final "\n" in the streamer
|
||||||
|
streamer_text = cs.out[:-1]
|
||||||
|
|
||||||
|
self.assertEqual(streamer_text, new_greedy_text)
|
||||||
|
|
||||||
|
def test_text_streamer_decode_kwargs(self):
|
||||||
|
# Tests that we can pass `decode_kwargs` to the streamer to control how the tokens are decoded. Must be tested
|
||||||
|
# with actual models -- the dummy models' tokenizers are not aligned with their models, and
|
||||||
|
# `skip_special_tokens=True` has no effect on them
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
|
||||||
|
model = AutoModelForCausalLM.from_pretrained("distilgpt2").to(torch_device)
|
||||||
|
model.config.eos_token_id = -1
|
||||||
|
|
||||||
|
input_ids = torch.ones((1, 5), device=torch_device).long() * model.config.bos_token_id
|
||||||
|
with CaptureStdout() as cs:
|
||||||
|
streamer = TextStreamer(tokenizer, skip_special_tokens=True)
|
||||||
|
model.generate(input_ids, max_new_tokens=1, do_sample=False, streamer=streamer)
|
||||||
|
|
||||||
|
# The prompt contains a special token, so the streamer should not print it. As such, the output text, when
|
||||||
|
# re-tokenized, must only contain one token
|
||||||
|
streamer_text = cs.out[:-1] # Remove the final "\n"
|
||||||
|
streamer_text_tokenized = tokenizer(streamer_text, return_tensors="pt")
|
||||||
|
self.assertEqual(streamer_text_tokenized.input_ids.shape, (1, 1))
|
||||||
|
|||||||
Reference in New Issue
Block a user