From 4b543c3007a57441550d87d5d61f06f7938d7140 Mon Sep 17 00:00:00 2001 From: Lorenzo Ampil Date: Sun, 22 Sep 2019 21:38:38 +0800 Subject: [PATCH] Add option to use a 'stop token' which will be used to truncate the output text to everything till right before the 'stop token' --- examples/run_generation.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/examples/run_generation.py b/examples/run_generation.py index a2a8f29103..27bc14e313 100644 --- a/examples/run_generation.py +++ b/examples/run_generation.py @@ -145,6 +145,8 @@ def main(): help="Avoid using CUDA when available") parser.add_argument('--seed', type=int, default=42, help="random seed for initialization") + parser.add_argument('--stop_token', type=str, default=None, + help="Token at which text generation is stopped") args = parser.parse_args() args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") @@ -185,6 +187,7 @@ def main(): ) out = out[0, len(context_tokens):].tolist() text = tokenizer.decode(out, clean_up_tokenization_spaces=True) + text = text[: text.find(args.stop_token) if args.stop_token else None] print(text) if args.prompt: break