From a17841ac4945631e4e13c072fa2a329b98ebb8b6 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Mon, 3 Apr 2023 19:49:38 +0200 Subject: [PATCH] Generate: Enable easier TextStreamer customization (#22516) --- src/transformers/generation/streamers.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/transformers/generation/streamers.py b/src/transformers/generation/streamers.py index 78d98666b3..06f7be9d63 100644 --- a/src/transformers/generation/streamers.py +++ b/src/transformers/generation/streamers.py @@ -88,7 +88,7 @@ class TextStreamer(BaseStreamer): printable_text = text[self.print_len : text.rfind(" ") + 1] self.print_len += len(printable_text) - print(printable_text, flush=True, end="") + self.on_finalized_text(printable_text) def end(self): """Flushes any remaining cache and prints a newline to stdout.""" @@ -102,7 +102,11 @@ class TextStreamer(BaseStreamer): printable_text = "" # Print a newline (and the remaining text, if any) - print(printable_text, flush=True) + self.on_finalized_text(printable_text, stream_end=True) + + def on_finalized_text(self, token: str, stream_end: bool = False): + """Prints the new text to stdout.""" + print(token, flush=True, end="" if not stream_end else None) class TextIteratorStreamer(BaseStreamer):