@@ -83,7 +83,8 @@ def run_model():
|
|||||||
elif args.length > model.config.n_ctx:
|
elif args.length > model.config.n_ctx:
|
||||||
raise ValueError("Can't get samples longer than window size: %s" % model.config.n_ctx)
|
raise ValueError("Can't get samples longer than window size: %s" % model.config.n_ctx)
|
||||||
|
|
||||||
while not args.unconditional:
|
while True:
|
||||||
|
context_tokens = []
|
||||||
if not args.unconditional:
|
if not args.unconditional:
|
||||||
raw_text = input("Model prompt >>> ")
|
raw_text = input("Model prompt >>> ")
|
||||||
while not raw_text:
|
while not raw_text:
|
||||||
@@ -106,6 +107,8 @@ def run_model():
|
|||||||
print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
|
print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
|
||||||
print(text)
|
print(text)
|
||||||
print("=" * 80)
|
print("=" * 80)
|
||||||
|
if args.unconditional:
|
||||||
|
break
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
run_model()
|
run_model()
|
||||||
|
|||||||
Reference in New Issue
Block a user