diff --git a/examples/run_generation.py b/examples/run_generation.py index 2d917660cf..fa52905b7e 100644 --- a/examples/run_generation.py +++ b/examples/run_generation.py @@ -247,7 +247,11 @@ def main(): out = out[:, len(context_tokens):].tolist() for o in out: text = tokenizer.decode(o, clean_up_tokenization_spaces=True) - text = text[: text.find(args.stop_token) if args.stop_token else None] + if args.stop_token: + index = text.find(args.stop_token) + if index == -1: + index = None + text = text[:index] print(text)