From 376c02e9a9196ee1ebb596dd588fc2c89450905a Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Wed, 5 Aug 2020 06:01:17 -0700 Subject: [PATCH] [WIP] lightning_base: support --lr_scheduler with multiple possibilities (#6232) * support --lr_scheduler with multiple possibilities * correct the error message * add a note about supported schedulers * cleanup * cleanup2 * needs the argument default * style * add another assert in the test * implement requested changes * cleanups * fix relative import * cleanup --- examples/lightning_base.py | 39 ++++++++++++-- examples/seq2seq/test_seq2seq_examples.py | 64 ++++++++++++++++++++++- 2 files changed, 98 insertions(+), 5 deletions(-) diff --git a/examples/lightning_base.py b/examples/lightning_base.py index ae03e29561..8543571830 100644 --- a/examples/lightning_base.py +++ b/examples/lightning_base.py @@ -20,6 +20,10 @@ from transformers import ( AutoTokenizer, PretrainedConfig, PreTrainedTokenizer, +) +from transformers.optimization import ( + get_cosine_schedule_with_warmup, + get_cosine_with_hard_restarts_schedule_with_warmup, get_linear_schedule_with_warmup, ) @@ -39,6 +43,19 @@ MODEL_MODES = { } +# update this and the import above to support new schedulers from transformers.optimization +arg_to_scheduler = { + "linear": get_linear_schedule_with_warmup, + "cosine": get_cosine_schedule_with_warmup, + "cosine_w_restarts": get_cosine_with_hard_restarts_schedule_with_warmup, + # polynomial': '', # TODO + # '': get_constant_schedule, # not supported for now + # '': get_constant_schedule_with_warmup, # not supported for now +} +arg_to_scheduler_choices = sorted(arg_to_scheduler.keys()) +arg_to_scheduler_metavar = "{" + ", ".join(arg_to_scheduler_choices) + "}" + + class BaseTransformer(pl.LightningModule): def __init__( self, @@ -97,6 +114,14 @@ class BaseTransformer(pl.LightningModule): def load_hf_checkpoint(self, *args, **kwargs): self.model = self.model_type.from_pretrained(*args, **kwargs) + def get_lr_scheduler(self): + get_schedule_func = arg_to_scheduler[self.hparams.lr_scheduler] + scheduler = get_schedule_func( + self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.total_steps + ) + scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1} + return scheduler + def configure_optimizers(self): """Prepare optimizer and schedule (linear warmup and decay)""" model = self.model @@ -114,10 +139,8 @@ class BaseTransformer(pl.LightningModule): optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon) self.opt = optimizer - scheduler = get_linear_schedule_with_warmup( - self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.total_steps - ) - scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1} + scheduler = self.get_lr_scheduler() + return [optimizer], [scheduler] def test_step(self, batch, batch_nb): @@ -203,6 +226,14 @@ class BaseTransformer(pl.LightningModule): "--attention_dropout", type=float, help="Attention dropout probability (Optional). Goes into model.config", ) parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.") + parser.add_argument( + "--lr_scheduler", + default="linear", + choices=arg_to_scheduler_choices, + metavar=arg_to_scheduler_metavar, + type=str, + help="Learning rate scheduler", + ) parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.") parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.") parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.") diff --git a/examples/seq2seq/test_seq2seq_examples.py b/examples/seq2seq/test_seq2seq_examples.py index d12aa03493..7473e0a64b 100644 --- a/examples/seq2seq/test_seq2seq_examples.py +++ b/examples/seq2seq/test_seq2seq_examples.py @@ -8,15 +8,17 @@ from pathlib import Path from unittest.mock import patch import pytest +import pytorch_lightning as pl import torch from pytest import param from torch.utils.data import DataLoader +import lightning_base from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, MBartTokenizer from transformers.testing_utils import require_multigpu from .distillation import distill_main, evaluate_checkpoint -from .finetune import main +from .finetune import SummarizationModule, main from .pack_dataset import pack_data_dir from .run_eval import generate_summaries_or_translations, run_generate from .utils import MBartDataset, Seq2SeqDataset, label_smoothed_nll_loss, lmap, load_json @@ -60,6 +62,7 @@ CHEAP_ARGS = { "tokenizer_name": "facebook/bart-large", "do_lower_case": False, "learning_rate": 0.3, + "lr_scheduler": "linear", "weight_decay": 0.0, "adam_epsilon": 1e-08, "warmup_steps": 0, @@ -326,6 +329,65 @@ def test_finetune_extra_model_args(): assert str(excinfo.value) == f"model config doesn't have a `{unsupported_param}` attribute" +def test_finetune_lr_shedulers(capsys): + args_d: dict = CHEAP_ARGS.copy() + + task = "summarization" + tmp_dir = make_test_data_dir() + + model = BART_TINY + output_dir = tempfile.mkdtemp(prefix="output_1_") + + args_d.update( + data_dir=tmp_dir, + model_name_or_path=model, + output_dir=output_dir, + tokenizer_name=None, + train_batch_size=2, + eval_batch_size=2, + do_predict=False, + task=task, + src_lang="en_XX", + tgt_lang="ro_RO", + freeze_encoder=True, + freeze_embeds=True, + ) + + # emulate finetune.py + parser = argparse.ArgumentParser() + parser = pl.Trainer.add_argparse_args(parser) + parser = SummarizationModule.add_model_specific_args(parser, os.getcwd()) + args = {"--help": True} + + # --help test + with pytest.raises(SystemExit) as excinfo: + args = parser.parse_args(args) + assert False, "--help is expected to sys.exit" + assert excinfo.type == SystemExit + captured = capsys.readouterr() + expected = lightning_base.arg_to_scheduler_metavar + assert expected in captured.out, "--help is expected to list the supported schedulers" + + # --lr_scheduler=non_existing_scheduler test + unsupported_param = "non_existing_scheduler" + args = {f"--lr_scheduler={unsupported_param}"} + with pytest.raises(SystemExit) as excinfo: + args = parser.parse_args(args) + assert False, "invalid argument is expected to sys.exit" + assert excinfo.type == SystemExit + captured = capsys.readouterr() + expected = f"invalid choice: '{unsupported_param}'" + assert expected in captured.err, f"should have bailed on invalid choice of scheduler {unsupported_param}" + + # --lr_scheduler=existing_scheduler test + supported_param = "cosine" + args_d1 = args_d.copy() + args_d1["lr_scheduler"] = supported_param + args = argparse.Namespace(**args_d1) + model = main(args) + assert getattr(model.hparams, "lr_scheduler") == supported_param, f"lr_scheduler={supported_param} shouldn't fail" + + def test_pack_dataset(): tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")