diff --git a/examples/run_gpt2.py b/examples/run_gpt2.py index 61fbf9f323..4b081d3a1d 100644 --- a/examples/run_gpt2.py +++ b/examples/run_gpt2.py @@ -58,7 +58,7 @@ def run_model(): parser.add_argument("--nsamples", type=int, default=1) parser.add_argument("--batch_size", type=int, default=-1) parser.add_argument("--length", type=int, default=-1) - parser.add_argument("--temperature", type=int, default=1) + parser.add_argument("--temperature", type=float, default=1.0) parser.add_argument("--top_k", type=int, default=0) parser.add_argument('--unconditional', action='store_true', help='If true, unconditional generation.') args = parser.parse_args()