fix run_gpt2.py

This commit is contained in:
Benjamin Mann
2019-04-08 17:20:35 -07:00
parent 94980b529f
commit fd8a3556f0

View File

@@ -83,7 +83,8 @@ def run_model():
elif args.length > model.config.n_ctx:
raise ValueError("Can't get samples longer than window size: %s" % model.config.n_ctx)
while not args.unconditional:
while True:
context_tokens = []
if not args.unconditional:
raw_text = input("Model prompt >>> ")
while not raw_text:
@@ -106,6 +107,8 @@ def run_model():
print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
print(text)
print("=" * 80)
if args.unconditional:
break
if __name__ == '__main__':
run_model()