Merge pull request #494 from SudoSharma/patch-1
Fix indentation for unconditional generation
This commit is contained in:
@@ -107,25 +107,25 @@ def run_model():
|
|||||||
print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
|
print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
|
||||||
print(text)
|
print(text)
|
||||||
print("=" * 80)
|
print("=" * 80)
|
||||||
if args.unconditional:
|
if args.unconditional:
|
||||||
generated = 0
|
generated = 0
|
||||||
for _ in range(args.nsamples // args.batch_size):
|
for _ in range(args.nsamples // args.batch_size):
|
||||||
out = sample_sequence(
|
out = sample_sequence(
|
||||||
model=model, length=args.length,
|
model=model, length=args.length,
|
||||||
context=None,
|
context=None,
|
||||||
start_token=enc.encoder['<|endoftext|>'],
|
start_token=enc.encoder['<|endoftext|>'],
|
||||||
batch_size=args.batch_size,
|
batch_size=args.batch_size,
|
||||||
temperature=args.temperature, top_k=args.top_k, device=device
|
temperature=args.temperature, top_k=args.top_k, device=device
|
||||||
)
|
)
|
||||||
out = out[:,1:].tolist()
|
out = out[:,1:].tolist()
|
||||||
for i in range(args.batch_size):
|
for i in range(args.batch_size):
|
||||||
generated += 1
|
generated += 1
|
||||||
text = enc.decode(out[i])
|
text = enc.decode(out[i])
|
||||||
print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
|
print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
|
||||||
print(text)
|
print(text)
|
||||||
print("=" * 80)
|
print("=" * 80)
|
||||||
if args.unconditional:
|
if args.unconditional:
|
||||||
break
|
break
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
run_model()
|
run_model()
|
||||||
|
|||||||
Reference in New Issue
Block a user