[WIP] lightning_base: support --lr_scheduler with multiple possibilities (#6232)
* support --lr_scheduler with multiple possibilities * correct the error message * add a note about supported schedulers * cleanup * cleanup2 * needs the argument default * style * add another assert in the test * implement requested changes * cleanups * fix relative import * cleanup
This commit is contained in:
@@ -20,6 +20,10 @@ from transformers import (
|
||||
AutoTokenizer,
|
||||
PretrainedConfig,
|
||||
PreTrainedTokenizer,
|
||||
)
|
||||
from transformers.optimization import (
|
||||
get_cosine_schedule_with_warmup,
|
||||
get_cosine_with_hard_restarts_schedule_with_warmup,
|
||||
get_linear_schedule_with_warmup,
|
||||
)
|
||||
|
||||
@@ -39,6 +43,19 @@ MODEL_MODES = {
|
||||
}
|
||||
|
||||
|
||||
# update this and the import above to support new schedulers from transformers.optimization
|
||||
arg_to_scheduler = {
|
||||
"linear": get_linear_schedule_with_warmup,
|
||||
"cosine": get_cosine_schedule_with_warmup,
|
||||
"cosine_w_restarts": get_cosine_with_hard_restarts_schedule_with_warmup,
|
||||
# polynomial': '', # TODO
|
||||
# '': get_constant_schedule, # not supported for now
|
||||
# '': get_constant_schedule_with_warmup, # not supported for now
|
||||
}
|
||||
arg_to_scheduler_choices = sorted(arg_to_scheduler.keys())
|
||||
arg_to_scheduler_metavar = "{" + ", ".join(arg_to_scheduler_choices) + "}"
|
||||
|
||||
|
||||
class BaseTransformer(pl.LightningModule):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -97,6 +114,14 @@ class BaseTransformer(pl.LightningModule):
|
||||
def load_hf_checkpoint(self, *args, **kwargs):
|
||||
self.model = self.model_type.from_pretrained(*args, **kwargs)
|
||||
|
||||
def get_lr_scheduler(self):
|
||||
get_schedule_func = arg_to_scheduler[self.hparams.lr_scheduler]
|
||||
scheduler = get_schedule_func(
|
||||
self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.total_steps
|
||||
)
|
||||
scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
|
||||
return scheduler
|
||||
|
||||
def configure_optimizers(self):
|
||||
"""Prepare optimizer and schedule (linear warmup and decay)"""
|
||||
model = self.model
|
||||
@@ -114,10 +139,8 @@ class BaseTransformer(pl.LightningModule):
|
||||
optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)
|
||||
self.opt = optimizer
|
||||
|
||||
scheduler = get_linear_schedule_with_warmup(
|
||||
self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.total_steps
|
||||
)
|
||||
scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
|
||||
scheduler = self.get_lr_scheduler()
|
||||
|
||||
return [optimizer], [scheduler]
|
||||
|
||||
def test_step(self, batch, batch_nb):
|
||||
@@ -203,6 +226,14 @@ class BaseTransformer(pl.LightningModule):
|
||||
"--attention_dropout", type=float, help="Attention dropout probability (Optional). Goes into model.config",
|
||||
)
|
||||
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
|
||||
parser.add_argument(
|
||||
"--lr_scheduler",
|
||||
default="linear",
|
||||
choices=arg_to_scheduler_choices,
|
||||
metavar=arg_to_scheduler_metavar,
|
||||
type=str,
|
||||
help="Learning rate scheduler",
|
||||
)
|
||||
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
|
||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
||||
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
|
||||
|
||||
@@ -8,15 +8,17 @@ from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from pytest import param
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
import lightning_base
|
||||
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, MBartTokenizer
|
||||
from transformers.testing_utils import require_multigpu
|
||||
|
||||
from .distillation import distill_main, evaluate_checkpoint
|
||||
from .finetune import main
|
||||
from .finetune import SummarizationModule, main
|
||||
from .pack_dataset import pack_data_dir
|
||||
from .run_eval import generate_summaries_or_translations, run_generate
|
||||
from .utils import MBartDataset, Seq2SeqDataset, label_smoothed_nll_loss, lmap, load_json
|
||||
@@ -60,6 +62,7 @@ CHEAP_ARGS = {
|
||||
"tokenizer_name": "facebook/bart-large",
|
||||
"do_lower_case": False,
|
||||
"learning_rate": 0.3,
|
||||
"lr_scheduler": "linear",
|
||||
"weight_decay": 0.0,
|
||||
"adam_epsilon": 1e-08,
|
||||
"warmup_steps": 0,
|
||||
@@ -326,6 +329,65 @@ def test_finetune_extra_model_args():
|
||||
assert str(excinfo.value) == f"model config doesn't have a `{unsupported_param}` attribute"
|
||||
|
||||
|
||||
def test_finetune_lr_shedulers(capsys):
|
||||
args_d: dict = CHEAP_ARGS.copy()
|
||||
|
||||
task = "summarization"
|
||||
tmp_dir = make_test_data_dir()
|
||||
|
||||
model = BART_TINY
|
||||
output_dir = tempfile.mkdtemp(prefix="output_1_")
|
||||
|
||||
args_d.update(
|
||||
data_dir=tmp_dir,
|
||||
model_name_or_path=model,
|
||||
output_dir=output_dir,
|
||||
tokenizer_name=None,
|
||||
train_batch_size=2,
|
||||
eval_batch_size=2,
|
||||
do_predict=False,
|
||||
task=task,
|
||||
src_lang="en_XX",
|
||||
tgt_lang="ro_RO",
|
||||
freeze_encoder=True,
|
||||
freeze_embeds=True,
|
||||
)
|
||||
|
||||
# emulate finetune.py
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
|
||||
args = {"--help": True}
|
||||
|
||||
# --help test
|
||||
with pytest.raises(SystemExit) as excinfo:
|
||||
args = parser.parse_args(args)
|
||||
assert False, "--help is expected to sys.exit"
|
||||
assert excinfo.type == SystemExit
|
||||
captured = capsys.readouterr()
|
||||
expected = lightning_base.arg_to_scheduler_metavar
|
||||
assert expected in captured.out, "--help is expected to list the supported schedulers"
|
||||
|
||||
# --lr_scheduler=non_existing_scheduler test
|
||||
unsupported_param = "non_existing_scheduler"
|
||||
args = {f"--lr_scheduler={unsupported_param}"}
|
||||
with pytest.raises(SystemExit) as excinfo:
|
||||
args = parser.parse_args(args)
|
||||
assert False, "invalid argument is expected to sys.exit"
|
||||
assert excinfo.type == SystemExit
|
||||
captured = capsys.readouterr()
|
||||
expected = f"invalid choice: '{unsupported_param}'"
|
||||
assert expected in captured.err, f"should have bailed on invalid choice of scheduler {unsupported_param}"
|
||||
|
||||
# --lr_scheduler=existing_scheduler test
|
||||
supported_param = "cosine"
|
||||
args_d1 = args_d.copy()
|
||||
args_d1["lr_scheduler"] = supported_param
|
||||
args = argparse.Namespace(**args_d1)
|
||||
model = main(args)
|
||||
assert getattr(model.hparams, "lr_scheduler") == supported_param, f"lr_scheduler={supported_param} shouldn't fail"
|
||||
|
||||
|
||||
def test_pack_dataset():
|
||||
tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user