From 889d3bfdbbc624815c19f0c4438319e3527d6e78 Mon Sep 17 00:00:00 2001 From: srush Date: Thu, 20 Feb 2020 15:31:17 -0500 Subject: [PATCH] default arg fix (#2937) --- examples/ner/transformer_base.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/examples/ner/transformer_base.py b/examples/ner/transformer_base.py index e5cb2f8009..fd119821fa 100644 --- a/examples/ner/transformer_base.py +++ b/examples/ner/transformer_base.py @@ -248,16 +248,22 @@ def generic_train(model, args): filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=5 ) - trainer = pl.Trainer( + train_params = dict( accumulate_grad_batches=args.gradient_accumulation_steps, gpus=args.n_gpu, max_epochs=args.num_train_epochs, - use_amp=args.fp16, - amp_level=args.fp16_opt_level, - distributed_backend="ddp", gradient_clip_val=args.max_grad_norm, checkpoint_callback=checkpoint_callback, ) + if args.fp16: + train_params["use_amp"] = args.fp16 + train_params["amp_level"] = args.fp16_opt_level + + if args.n_gpu > 1: + train_params["distributed_backend"] = "ddp" + + trainer = pl.Trainer(**train_params) + if args.do_train: trainer.fit(model)