default arg fix (#2937)

This commit is contained in:
srush
2020-02-20 15:31:17 -05:00
committed by GitHub
parent ea8eba35e2
commit 889d3bfdbb

View File

@@ -248,16 +248,22 @@ def generic_train(model, args):
filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=5
)
trainer = pl.Trainer(
train_params = dict(
accumulate_grad_batches=args.gradient_accumulation_steps,
gpus=args.n_gpu,
max_epochs=args.num_train_epochs,
use_amp=args.fp16,
amp_level=args.fp16_opt_level,
distributed_backend="ddp",
gradient_clip_val=args.max_grad_norm,
checkpoint_callback=checkpoint_callback,
)
if args.fp16:
train_params["use_amp"] = args.fp16
train_params["amp_level"] = args.fp16_opt_level
if args.n_gpu > 1:
train_params["distributed_backend"] = "ddp"
trainer = pl.Trainer(**train_params)
if args.do_train:
trainer.fit(model)