From 4d3cf0d6028d7576b8c51ba1eda8403e86b42b05 Mon Sep 17 00:00:00 2001 From: Dhanajit Brahma Date: Sun, 7 Apr 2019 16:59:07 +0530 Subject: [PATCH] removing some redundant lines --- examples/run_gpt2.py | 37 +++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/examples/run_gpt2.py b/examples/run_gpt2.py index 0289b26702..b22df39b98 100644 --- a/examples/run_gpt2.py +++ b/examples/run_gpt2.py @@ -83,29 +83,29 @@ def run_model(): elif args.length > model.config.n_ctx: raise ValueError("Can't get samples longer than window size: %s" % model.config.n_ctx) - while not args.unconditional: - if not args.unconditional: + if not args.unconditional: + while True: raw_text = input("Model prompt >>> ") while not raw_text: print('Prompt should not be empty!') raw_text = input("Model prompt >>> ") context_tokens = enc.encode(raw_text) - generated = 0 - for _ in range(args.nsamples // args.batch_size): - out = sample_sequence( - model=model, length=args.length, - context=context_tokens if not args.unconditional else None, - start_token=enc.encoder['<|endoftext|>'] if args.unconditional else None, - batch_size=args.batch_size, - temperature=args.temperature, top_k=args.top_k, device=device - ) - out = out[:, len(context_tokens):].tolist() - for i in range(args.batch_size): - generated += 1 - text = enc.decode(out[i]) - print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40) - print(text) - print("=" * 80) + generated = 0 + for _ in range(args.nsamples // args.batch_size): + out = sample_sequence( + model=model, length=args.length, + context=context_tokens, + start_token=None, + batch_size=args.batch_size, + temperature=args.temperature, top_k=args.top_k, device=device + ) + out = out[:, len(context_tokens):].tolist() + for i in range(args.batch_size): + generated += 1 + text = enc.decode(out[i]) + print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40) + print(text) + print("=" * 80) if args.unconditional: generated = 0 for _ in range(args.nsamples // args.batch_size): @@ -127,3 +127,4 @@ def run_model(): if __name__ == '__main__': run_model() +