From 529850ae7bca0ff388778c3c0d66240834cf56c3 Mon Sep 17 00:00:00 2001 From: Nathan Raw Date: Fri, 17 Jul 2020 20:43:06 -0600 Subject: [PATCH] Lightning Updates for v0.8.5 (#5798) Co-authored-by: Sam Shleifer --- examples/lightning_base.py | 122 +++++++++------------- examples/requirements.txt | 2 +- examples/seq2seq/README.md | 6 +- examples/seq2seq/finetune.py | 25 +++-- examples/seq2seq/finetune.sh | 1 - examples/seq2seq/test_seq2seq_examples.py | 12 +-- examples/seq2seq/train_mbart_cc25_enro.sh | 2 +- 7 files changed, 73 insertions(+), 97 deletions(-) diff --git a/examples/lightning_base.py b/examples/lightning_base.py index 988d292b09..1124f57662 100644 --- a/examples/lightning_base.py +++ b/examples/lightning_base.py @@ -1,14 +1,11 @@ import argparse import logging import os -import random from pathlib import Path from typing import Any, Dict -import numpy as np import pytorch_lightning as pl -import torch -from pytorch_lightning.utilities import rank_zero_info, rank_zero_only +from pytorch_lightning.utilities import rank_zero_info from transformers import ( AdamW, @@ -42,14 +39,6 @@ MODEL_MODES = { } -def set_seed(args: argparse.Namespace): - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - if args.gpus > 0: - torch.cuda.manual_seed_all(args.seed) - - class BaseTransformer(pl.LightningModule): def __init__( self, @@ -63,7 +52,11 @@ class BaseTransformer(pl.LightningModule): ): """Initialize a model, tokenizer and config.""" super().__init__() - self.hparams = hparams # TODO: move to self.save_hyperparameters() + # TODO: move to self.save_hyperparameters() + # self.save_hyperparameters() + # can also expand arguments into trainer signature for easier reading + + self.hparams = hparams self.step_count = 0 self.tfmr_ckpts = {} self.output_dir = Path(self.hparams.output_dir) @@ -114,17 +107,12 @@ class BaseTransformer(pl.LightningModule): ] optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon) self.opt = optimizer - return [optimizer] - def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None): - if self.trainer.use_tpu: - xm.optimizer_step(optimizer) - else: - optimizer.step() - optimizer.zero_grad() - self.lr_scheduler.step() # By default, PL will only step every epoch. - lrs = {f"lr_group_{i}": lr for i, lr in enumerate(self.lr_scheduler.get_lr())} - self.logger.log_metrics(lrs) + scheduler = get_linear_schedule_with_warmup( + self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.total_steps + ) + scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1} + return [optimizer], [scheduler] def test_step(self, batch, batch_nb): return self.validation_step(batch, batch_nb) @@ -132,26 +120,24 @@ class BaseTransformer(pl.LightningModule): def test_epoch_end(self, outputs): return self.validation_end(outputs) - def train_dataloader(self): + def setup(self, step): train_batch_size = self.hparams.train_batch_size - dataloader = self.load_dataset("train", train_batch_size) + dataloader = self.get_dataloader("train", train_batch_size) + self.train_loader = dataloader + self.total_steps = ( + (len(dataloader.dataset) // (train_batch_size * max(1, self.hparams.gpus))) + // self.hparams.accumulate_grad_batches + * float(self.hparams.max_epochs) + ) - t_total = ( - (len(dataloader.dataset) // (train_batch_size * max(1, self.hparams.n_gpu))) - // self.hparams.gradient_accumulation_steps - * float(self.hparams.num_train_epochs) - ) - scheduler = get_linear_schedule_with_warmup( - self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total - ) - self.lr_scheduler = scheduler - return dataloader + def train_dataloader(self): + return self.train_loader def val_dataloader(self): - return self.load_dataset("dev", self.hparams.eval_batch_size) + return self.get_dataloader("dev", self.hparams.eval_batch_size) def test_dataloader(self): - return self.load_dataset("test", self.hparams.eval_batch_size) + return self.get_dataloader("test", self.hparams.eval_batch_size) def _feature_file(self, mode): return os.path.join( @@ -201,16 +187,16 @@ class BaseTransformer(pl.LightningModule): parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.") parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.") parser.add_argument("--num_workers", default=4, type=int, help="kwarg passed to DataLoader") - parser.add_argument( - "--num_train_epochs", default=3, type=int, help="Total number of training epochs to perform." - ) - + parser.add_argument("--num_train_epochs", dest="max_epochs", default=3, type=int) parser.add_argument("--train_batch_size", default=32, type=int) parser.add_argument("--eval_batch_size", default=32, type=int) class LoggingCallback(pl.Callback): - @rank_zero_only + def on_batch_end(self, trainer, pl_module): + lrs = {f"lr_group_{i}": lr for i, lr in enumerate(self.lr_scheduler.get_lr())} + pl_module.logger.log_metrics(lrs) + def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): rank_zero_info("***** Validation results *****") metrics = trainer.callback_metrics @@ -219,16 +205,15 @@ class LoggingCallback(pl.Callback): if key not in ["log", "progress_bar"]: rank_zero_info("{} = {}\n".format(key, str(metrics[key]))) - @rank_zero_only def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): - logger.info("***** Test results *****") + rank_zero_info("***** Test results *****") metrics = trainer.callback_metrics # Log and save results to file output_test_results_file = os.path.join(pl_module.hparams.output_dir, "test_results.txt") with open(output_test_results_file, "w") as writer: for key in sorted(metrics): if key not in ["log", "progress_bar"]: - logger.info("{} = {}\n".format(key, str(metrics[key]))) + rank_zero_info("{} = {}\n".format(key, str(metrics[key]))) writer.write("{} = {}\n".format(key, str(metrics[key]))) @@ -251,26 +236,23 @@ def add_generic_args(parser, root_dir) -> None: parser.add_argument( "--fp16_opt_level", type=str, - default="O1", + default="O2", help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." "See details at https://nvidia.github.io/apex/amp.html", ) - parser.add_argument("--fast_dev_run", action="store_true") - parser.add_argument("--gpus", type=int, default=1) - parser.add_argument("--n_tpu_cores", type=int, default=0) - parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int, default=0) + parser.add_argument("--max_grad_norm", dest="gradient_clip_val", default=1.0, type=float, help="Max gradient norm") parser.add_argument("--do_train", action="store_true", help="Whether to run training.") parser.add_argument("--do_predict", action="store_true", help="Whether to run predictions on the test set.") parser.add_argument( "--gradient_accumulation_steps", + dest="accumulate_grad_batches", type=int, default=1, help="Number of updates steps to accumulate before performing a backward/update pass.", ) parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") - parser.add_argument("--resume_from_checkpoint", type=str, default=None) - parser.add_argument("--val_check_interval", default=1.0, type=float) def generic_train( @@ -283,10 +265,13 @@ def generic_train( logging_callback=None, **extra_train_kwargs ): + pl.seed_everything(args.seed) + # init model - set_seed(args) odir = Path(model.hparams.output_dir) odir.mkdir(exist_ok=True) + + # add custom checkpoints if checkpoint_callback is None: checkpoint_callback = pl.callbacks.ModelCheckpoint( filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=1 @@ -296,38 +281,25 @@ def generic_train( train_params = {} + # TODO: remove with PyTorch 1.6 since pl uses native amp if args.fp16: - train_params["use_amp"] = args.fp16 + train_params["precision"] = 16 train_params["amp_level"] = args.fp16_opt_level - if args.n_tpu_cores > 0: - global xm - import torch_xla.core.xla_model as xm - - train_params["num_tpu_cores"] = args.n_tpu_cores - train_params["gpus"] = 0 - if args.gpus > 1: train_params["distributed_backend"] = "ddp" - trainer = pl.Trainer( - logger=logger, - accumulate_grad_batches=args.gradient_accumulation_steps, - gpus=args.gpus, - max_epochs=args.num_train_epochs, - early_stop_callback=early_stopping_callback, - gradient_clip_val=args.max_grad_norm, - checkpoint_callback=checkpoint_callback, - callbacks=[logging_callback] + extra_callbacks, - fast_dev_run=args.fast_dev_run, - val_check_interval=args.val_check_interval, + trainer = pl.Trainer.from_argparse_args( + args, weights_summary=None, - resume_from_checkpoint=args.resume_from_checkpoint, + callbacks=[logging_callback] + extra_callbacks, + logger=logger, + checkpoint_callback=checkpoint_callback, + early_stop_callback=early_stopping_callback, **train_params, ) if args.do_train: trainer.fit(model) - trainer.logger.log_hyperparams(args) - trainer.logger.save() + return trainer diff --git a/examples/requirements.txt b/examples/requirements.txt index 6ab5c2c05a..028dd7f8fd 100644 --- a/examples/requirements.txt +++ b/examples/requirements.txt @@ -5,7 +5,7 @@ psutil sacrebleu rouge-score tensorflow_datasets -pytorch-lightning==0.8.1 +pytorch-lightning==0.8.5 matplotlib git-python==1.0.3 faiss diff --git a/examples/seq2seq/README.md b/examples/seq2seq/README.md index fbeab97cbd..fdf3e83617 100644 --- a/examples/seq2seq/README.md +++ b/examples/seq2seq/README.md @@ -60,7 +60,7 @@ Summarization Tips: - If you want to run experiments on improving the summarization finetuning process, try the XSUM Shared Task (below). It's faster to train than CNNDM because the summaries are shorter. - For CNN/DailyMail, the default `val_max_target_length` and `test_max_target_length` will truncate the ground truth labels, resulting in slightly higher rouge scores. To get accurate rouge scores, you should rerun calculate_rouge on the `{output_dir}/test_generations.txt` file saved by `trainer.test()` - `--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 ` is a reasonable setting for XSUM. -- `wandb` can be used by specifying `--logger wandb`. It is useful for reproducibility. Specify the environment variable `WANDB_PROJECT='hf_xsum'` to do the XSUM shared task. +- `wandb` can be used by specifying `--logger_name wandb`. It is useful for reproducibility. Specify the environment variable `WANDB_PROJECT='hf_xsum'` to do the XSUM shared task. - If you are finetuning on your own dataset, start from `distilbart-cnn-12-6` if you want long summaries and `distilbart-xsum-12-6` if you want short summaries. (It rarely makes sense to start from `bart-large` unless you are a researching finetuning methods). @@ -124,7 +124,7 @@ model = AutoModelForSeq2SeqLM.from_pretrained(f'{output_dir}/best_tfmr') ``` #### XSUM Shared Task -Compare XSUM results with others by using `--logger wandb_shared`. This requires `wandb` registration. +Compare XSUM results with others by using `--logger_name wandb_shared`. This requires `wandb` registration. Here is an example command, but you can do whatever you want. Hopefully this will make debugging and collaboration easier! ```bash @@ -135,7 +135,7 @@ WANDB_PROJECT='hf_xsum' ./finetune.sh \ --train_batch_size 16 --eval_batch_size 16 --freeze_embeds --freeze_encoder \ --num_train_epochs 6 \ --max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 \ - --logger wandb + --logger_name wandb ``` You can see your wandb logs [here](https://app.wandb.ai/sshleifer/hf_xsum?workspace=user-) diff --git a/examples/seq2seq/finetune.py b/examples/seq2seq/finetune.py index 16b7335b2f..cd33892680 100644 --- a/examples/seq2seq/finetune.py +++ b/examples/seq2seq/finetune.py @@ -221,8 +221,8 @@ class SummarizationModule(BaseTransformer): dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True) t_total = ( (len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.gpus))) - // self.hparams.gradient_accumulation_steps - * float(self.hparams.num_train_epochs) + // self.hparams.accumulate_grad_batches + * float(self.hparams.max_epochs) ) scheduler = get_linear_schedule_with_warmup( self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total @@ -279,7 +279,7 @@ class SummarizationModule(BaseTransformer): parser.add_argument("--freeze_encoder", action="store_true") parser.add_argument("--freeze_embeds", action="store_true") parser.add_argument("--sortish_sampler", action="store_true", default=False) - parser.add_argument("--logger", type=str, choices=["default", "wandb", "wandb_shared"], default="default") + parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default") parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.") parser.add_argument("--n_val", type=int, default=500, required=False, help="# examples. -1 means use all.") parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.") @@ -288,7 +288,6 @@ class SummarizationModule(BaseTransformer): ) parser.add_argument("--src_lang", type=str, default="", required=False) parser.add_argument("--tgt_lang", type=str, default="", required=False) - return parser @@ -318,22 +317,24 @@ def main(args, model=None) -> SummarizationModule: model: SummarizationModule = SummarizationModule(args) else: model: SummarizationModule = TranslationModule(args) + + dataset = Path(args.data_dir).name if ( - args.logger == "default" + args.logger_name == "default" or args.fast_dev_run or str(args.output_dir).startswith("/tmp") or str(args.output_dir).startswith("/var") ): logger = True # don't pollute wandb logs unnecessarily - elif args.logger == "wandb": + elif args.logger_name == "wandb": from pytorch_lightning.loggers import WandbLogger - logger = WandbLogger(name=model.output_dir.name) + logger = WandbLogger(name=model.output_dir.name, project=dataset) - elif args.logger == "wandb_shared": + elif args.logger_name == "wandb_shared": from pytorch_lightning.loggers import WandbLogger - logger = WandbLogger(name=model.output_dir.name) + logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}") trainer: pl.Trainer = generic_train( model, args, @@ -352,13 +353,17 @@ def main(args, model=None) -> SummarizationModule: model.hparams.test_checkpoint = checkpoints[-1] trainer.resume_from_checkpoint = checkpoints[-1] trainer.logger.log_hyperparams(model.hparams) - trainer.test(model) # this breaks in DDP, known lightning issue. See evaluate_checkpoint to recover metrics. + + # test() without a model tests using the best checkpoint automatically + trainer.test() return model if __name__ == "__main__": parser = argparse.ArgumentParser() + parser = pl.Trainer.add_argparse_args(parser) parser = SummarizationModule.add_model_specific_args(parser, os.getcwd()) + args = parser.parse_args() main(args) diff --git a/examples/seq2seq/finetune.sh b/examples/seq2seq/finetune.sh index 89bd68eaae..78de9a3f74 100755 --- a/examples/seq2seq/finetune.sh +++ b/examples/seq2seq/finetune.sh @@ -10,5 +10,4 @@ python finetune.py \ --do_predict \ --n_val 1000 \ --val_check_interval 0.1 \ - --sortish_sampler \ $@ diff --git a/examples/seq2seq/test_seq2seq_examples.py b/examples/seq2seq/test_seq2seq_examples.py index 0a6c59ab38..61f487d654 100644 --- a/examples/seq2seq/test_seq2seq_examples.py +++ b/examples/seq2seq/test_seq2seq_examples.py @@ -26,7 +26,7 @@ logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger() CUDA_AVAILABLE = torch.cuda.is_available() CHEAP_ARGS = { - "logger": "default", + "logger_name": "default", "length_penalty": 0.5, "cache_dir": "", "task": "summarization", @@ -48,7 +48,7 @@ CHEAP_ARGS = { "max_grad_norm": 1.0, "do_train": True, "do_predict": True, - "gradient_accumulation_steps": 1, + "accumulate_grad_batches": 1, "server_ip": "", "server_port": "", "seed": 42, @@ -60,7 +60,7 @@ CHEAP_ARGS = { "weight_decay": 0.0, "adam_epsilon": 1e-08, "warmup_steps": 0, - "num_train_epochs": 1, + "max_epochs": 1, "train_batch_size": 2, "eval_batch_size": 2, "max_source_length": 12, @@ -122,7 +122,7 @@ class TestSummarizationDistiller(unittest.TestCase): updates = dict( student_encoder_layers=2, student_decoder_layers=1, - num_train_epochs=4, + max_epochs=4, val_check_interval=0.25, alpha_hid=2.0, model_name_or_path="IGNORE_THIS_IT_DOESNT_GET_USED", @@ -156,7 +156,7 @@ class TestSummarizationDistiller(unittest.TestCase): default_updates = dict( train_batch_size=1, eval_batch_size=2, - num_train_epochs=2, + max_epochs=2, alpha_mlm=0.2, alpha_ce=0.8, do_predict=True, @@ -187,7 +187,7 @@ class TestSummarizationDistiller(unittest.TestCase): self.assertGreaterEqual(last_step_stats["val_avg_gen_time"], 0.01) self.assertGreaterEqual(1.0, last_step_stats["val_avg_gen_time"]) self.assertIsInstance(last_step_stats[f"val_avg_{model.val_metric}"], float) - desired_n_evals = int(args_d["num_train_epochs"] * (1 / args_d["val_check_interval"]) + 1) + desired_n_evals = int(args_d["max_epochs"] * (1 / args_d["val_check_interval"]) + 1) self.assertEqual(len(metrics["val"]), desired_n_evals) self.assertEqual(len(metrics["test"]), 1) return model diff --git a/examples/seq2seq/train_mbart_cc25_enro.sh b/examples/seq2seq/train_mbart_cc25_enro.sh index 2fd5268cd4..4dcbe9ec1b 100755 --- a/examples/seq2seq/train_mbart_cc25_enro.sh +++ b/examples/seq2seq/train_mbart_cc25_enro.sh @@ -17,5 +17,5 @@ python finetune.py \ --model_name_or_path facebook/mbart-large-cc25 \ --task translation \ --warmup_steps 500 \ - --logger wandb --sortish_sampler \ + --logger_name wandb --sortish_sampler \ $@