Lightning Updates for v0.8.5 (#5798)

Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
Nathan Raw
2020-07-17 20:43:06 -06:00
committed by GitHub
parent 615be03f9d
commit 529850ae7b
7 changed files with 73 additions and 97 deletions

View File

@@ -1,14 +1,11 @@
import argparse
import logging
import os
import random
from pathlib import Path
from typing import Any, Dict
import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
from pytorch_lightning.utilities import rank_zero_info
from transformers import (
AdamW,
@@ -42,14 +39,6 @@ MODEL_MODES = {
}
def set_seed(args: argparse.Namespace):
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.gpus > 0:
torch.cuda.manual_seed_all(args.seed)
class BaseTransformer(pl.LightningModule):
def __init__(
self,
@@ -63,7 +52,11 @@ class BaseTransformer(pl.LightningModule):
):
"""Initialize a model, tokenizer and config."""
super().__init__()
self.hparams = hparams # TODO: move to self.save_hyperparameters()
# TODO: move to self.save_hyperparameters()
# self.save_hyperparameters()
# can also expand arguments into trainer signature for easier reading
self.hparams = hparams
self.step_count = 0
self.tfmr_ckpts = {}
self.output_dir = Path(self.hparams.output_dir)
@@ -114,17 +107,12 @@ class BaseTransformer(pl.LightningModule):
]
optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)
self.opt = optimizer
return [optimizer]
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None):
if self.trainer.use_tpu:
xm.optimizer_step(optimizer)
else:
optimizer.step()
optimizer.zero_grad()
self.lr_scheduler.step() # By default, PL will only step every epoch.
lrs = {f"lr_group_{i}": lr for i, lr in enumerate(self.lr_scheduler.get_lr())}
self.logger.log_metrics(lrs)
scheduler = get_linear_schedule_with_warmup(
self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.total_steps
)
scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
return [optimizer], [scheduler]
def test_step(self, batch, batch_nb):
return self.validation_step(batch, batch_nb)
@@ -132,26 +120,24 @@ class BaseTransformer(pl.LightningModule):
def test_epoch_end(self, outputs):
return self.validation_end(outputs)
def train_dataloader(self):
def setup(self, step):
train_batch_size = self.hparams.train_batch_size
dataloader = self.load_dataset("train", train_batch_size)
dataloader = self.get_dataloader("train", train_batch_size)
self.train_loader = dataloader
self.total_steps = (
(len(dataloader.dataset) // (train_batch_size * max(1, self.hparams.gpus)))
// self.hparams.accumulate_grad_batches
* float(self.hparams.max_epochs)
)
t_total = (
(len(dataloader.dataset) // (train_batch_size * max(1, self.hparams.n_gpu)))
// self.hparams.gradient_accumulation_steps
* float(self.hparams.num_train_epochs)
)
scheduler = get_linear_schedule_with_warmup(
self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
)
self.lr_scheduler = scheduler
return dataloader
def train_dataloader(self):
return self.train_loader
def val_dataloader(self):
return self.load_dataset("dev", self.hparams.eval_batch_size)
return self.get_dataloader("dev", self.hparams.eval_batch_size)
def test_dataloader(self):
return self.load_dataset("test", self.hparams.eval_batch_size)
return self.get_dataloader("test", self.hparams.eval_batch_size)
def _feature_file(self, mode):
return os.path.join(
@@ -201,16 +187,16 @@ class BaseTransformer(pl.LightningModule):
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
parser.add_argument("--num_workers", default=4, type=int, help="kwarg passed to DataLoader")
parser.add_argument(
"--num_train_epochs", default=3, type=int, help="Total number of training epochs to perform."
)
parser.add_argument("--num_train_epochs", dest="max_epochs", default=3, type=int)
parser.add_argument("--train_batch_size", default=32, type=int)
parser.add_argument("--eval_batch_size", default=32, type=int)
class LoggingCallback(pl.Callback):
@rank_zero_only
def on_batch_end(self, trainer, pl_module):
lrs = {f"lr_group_{i}": lr for i, lr in enumerate(self.lr_scheduler.get_lr())}
pl_module.logger.log_metrics(lrs)
def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
rank_zero_info("***** Validation results *****")
metrics = trainer.callback_metrics
@@ -219,16 +205,15 @@ class LoggingCallback(pl.Callback):
if key not in ["log", "progress_bar"]:
rank_zero_info("{} = {}\n".format(key, str(metrics[key])))
@rank_zero_only
def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
logger.info("***** Test results *****")
rank_zero_info("***** Test results *****")
metrics = trainer.callback_metrics
# Log and save results to file
output_test_results_file = os.path.join(pl_module.hparams.output_dir, "test_results.txt")
with open(output_test_results_file, "w") as writer:
for key in sorted(metrics):
if key not in ["log", "progress_bar"]:
logger.info("{} = {}\n".format(key, str(metrics[key])))
rank_zero_info("{} = {}\n".format(key, str(metrics[key])))
writer.write("{} = {}\n".format(key, str(metrics[key])))
@@ -251,26 +236,23 @@ def add_generic_args(parser, root_dir) -> None:
parser.add_argument(
"--fp16_opt_level",
type=str,
default="O1",
default="O2",
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
"See details at https://nvidia.github.io/apex/amp.html",
)
parser.add_argument("--fast_dev_run", action="store_true")
parser.add_argument("--gpus", type=int, default=1)
parser.add_argument("--n_tpu_cores", type=int, default=0)
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int, default=0)
parser.add_argument("--max_grad_norm", dest="gradient_clip_val", default=1.0, type=float, help="Max gradient norm")
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
parser.add_argument("--do_predict", action="store_true", help="Whether to run predictions on the test set.")
parser.add_argument(
"--gradient_accumulation_steps",
dest="accumulate_grad_batches",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
parser.add_argument("--resume_from_checkpoint", type=str, default=None)
parser.add_argument("--val_check_interval", default=1.0, type=float)
def generic_train(
@@ -283,10 +265,13 @@ def generic_train(
logging_callback=None,
**extra_train_kwargs
):
pl.seed_everything(args.seed)
# init model
set_seed(args)
odir = Path(model.hparams.output_dir)
odir.mkdir(exist_ok=True)
# add custom checkpoints
if checkpoint_callback is None:
checkpoint_callback = pl.callbacks.ModelCheckpoint(
filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=1
@@ -296,38 +281,25 @@ def generic_train(
train_params = {}
# TODO: remove with PyTorch 1.6 since pl uses native amp
if args.fp16:
train_params["use_amp"] = args.fp16
train_params["precision"] = 16
train_params["amp_level"] = args.fp16_opt_level
if args.n_tpu_cores > 0:
global xm
import torch_xla.core.xla_model as xm
train_params["num_tpu_cores"] = args.n_tpu_cores
train_params["gpus"] = 0
if args.gpus > 1:
train_params["distributed_backend"] = "ddp"
trainer = pl.Trainer(
logger=logger,
accumulate_grad_batches=args.gradient_accumulation_steps,
gpus=args.gpus,
max_epochs=args.num_train_epochs,
early_stop_callback=early_stopping_callback,
gradient_clip_val=args.max_grad_norm,
checkpoint_callback=checkpoint_callback,
callbacks=[logging_callback] + extra_callbacks,
fast_dev_run=args.fast_dev_run,
val_check_interval=args.val_check_interval,
trainer = pl.Trainer.from_argparse_args(
args,
weights_summary=None,
resume_from_checkpoint=args.resume_from_checkpoint,
callbacks=[logging_callback] + extra_callbacks,
logger=logger,
checkpoint_callback=checkpoint_callback,
early_stop_callback=early_stopping_callback,
**train_params,
)
if args.do_train:
trainer.fit(model)
trainer.logger.log_hyperparams(args)
trainer.logger.save()
return trainer