[Fix] text-classification PL example (#6027)

Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
Bhashithe Abeysinghe
2020-08-06 15:46:43 -04:00
committed by GitHub
parent eb2bd8d6eb
commit ffceef2042
3 changed files with 14 additions and 6 deletions

View File

@@ -73,7 +73,7 @@ class BaseTransformer(pl.LightningModule):
# self.save_hyperparameters()
# can also expand arguments into trainer signature for easier reading
self.hparams = hparams
self.save_hyperparameters(hparams)
self.step_count = 0
self.output_dir = Path(self.hparams.output_dir)
cache_dir = self.hparams.cache_dir if self.hparams.cache_dir else None
@@ -245,7 +245,7 @@ class BaseTransformer(pl.LightningModule):
class LoggingCallback(pl.Callback):
def on_batch_end(self, trainer, pl_module):
lrs = {f"lr_group_{i}": lr for i, lr in enumerate(self.lr_scheduler.get_lr())}
lrs = {f"lr_group_{i}": param["lr"] for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups)}
pl_module.logger.log_metrics(lrs)
def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
@@ -278,6 +278,10 @@ def add_generic_args(parser, root_dir) -> None:
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument(
"--gpus", default=0, type=int, help="The number of GPUs allocated for this, it is by default 0 meaning none",
)
parser.add_argument(
"--fp16",
action="store_true",
@@ -291,7 +295,7 @@ def add_generic_args(parser, root_dir) -> None:
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
"See details at https://nvidia.github.io/apex/amp.html",
)
parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int, default=0)
parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int)
parser.add_argument("--max_grad_norm", dest="gradient_clip_val", default=1.0, type=float, help="Max gradient norm")
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
parser.add_argument("--do_predict", action="store_true", help="Whether to run predictions on the test set.")