default arg fix (#2937)
This commit is contained in:
@@ -248,16 +248,22 @@ def generic_train(model, args):
|
|||||||
filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=5
|
filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=5
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer = pl.Trainer(
|
train_params = dict(
|
||||||
accumulate_grad_batches=args.gradient_accumulation_steps,
|
accumulate_grad_batches=args.gradient_accumulation_steps,
|
||||||
gpus=args.n_gpu,
|
gpus=args.n_gpu,
|
||||||
max_epochs=args.num_train_epochs,
|
max_epochs=args.num_train_epochs,
|
||||||
use_amp=args.fp16,
|
|
||||||
amp_level=args.fp16_opt_level,
|
|
||||||
distributed_backend="ddp",
|
|
||||||
gradient_clip_val=args.max_grad_norm,
|
gradient_clip_val=args.max_grad_norm,
|
||||||
checkpoint_callback=checkpoint_callback,
|
checkpoint_callback=checkpoint_callback,
|
||||||
)
|
)
|
||||||
|
if args.fp16:
|
||||||
|
train_params["use_amp"] = args.fp16
|
||||||
|
train_params["amp_level"] = args.fp16_opt_level
|
||||||
|
|
||||||
|
if args.n_gpu > 1:
|
||||||
|
train_params["distributed_backend"] = "ddp"
|
||||||
|
|
||||||
|
trainer = pl.Trainer(**train_params)
|
||||||
|
|
||||||
if args.do_train:
|
if args.do_train:
|
||||||
trainer.fit(model)
|
trainer.fit(model)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user