[WIP] lightning_base: support --lr_scheduler with multiple possibilities (#6232)
* support --lr_scheduler with multiple possibilities
* correct the error message
* add a note about supported schedulers
* cleanup
* cleanup2
* needs the argument default
* style
* add another assert in the test
* implement requested changes
* cleanups
* fix relative import
* cleanup
This commit is contained in:
@@ -20,6 +20,10 @@ from transformers import (
|
|||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
PretrainedConfig,
|
PretrainedConfig,
|
||||||
PreTrainedTokenizer,
|
PreTrainedTokenizer,
|
||||||
|
)
|
||||||
|
from transformers.optimization import (
|
||||||
|
get_cosine_schedule_with_warmup,
|
||||||
|
get_cosine_with_hard_restarts_schedule_with_warmup,
|
||||||
get_linear_schedule_with_warmup,
|
get_linear_schedule_with_warmup,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -39,6 +43,19 @@ MODEL_MODES = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Registry mapping each accepted `--lr_scheduler` value to the schedule factory
# imported from transformers.optimization. To support a new scheduler, import
# its factory above and add an entry here.
#
# Not supported for now (no entry yet):
#   - "polynomial"  (TODO)
#   - get_constant_schedule
#   - get_constant_schedule_with_warmup
arg_to_scheduler = {
    "linear": get_linear_schedule_with_warmup,
    "cosine": get_cosine_schedule_with_warmup,
    "cosine_w_restarts": get_cosine_with_hard_restarts_schedule_with_warmup,
}
# Sorted scheduler names, used as the argparse `choices`.
arg_to_scheduler_choices = sorted(arg_to_scheduler)
# Brace-wrapped, comma-separated list of the names, used as the argparse `metavar`.
arg_to_scheduler_metavar = "{%s}" % ", ".join(arg_to_scheduler_choices)
|
||||||
|
|
||||||
|
|
||||||
class BaseTransformer(pl.LightningModule):
|
class BaseTransformer(pl.LightningModule):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -97,6 +114,14 @@ class BaseTransformer(pl.LightningModule):
|
|||||||
def load_hf_checkpoint(self, *args, **kwargs):
|
def load_hf_checkpoint(self, *args, **kwargs):
|
||||||
self.model = self.model_type.from_pretrained(*args, **kwargs)
|
self.model = self.model_type.from_pretrained(*args, **kwargs)
|
||||||
|
|
||||||
|
def get_lr_scheduler(self):
    """Build the learning-rate scheduler selected by ``hparams.lr_scheduler``.

    Looks up the schedule factory in ``arg_to_scheduler`` and calls it with
    ``self.opt``, ``hparams.warmup_steps`` and ``self.total_steps``.

    Returns:
        dict: the scheduler under ``"scheduler"``, together with
        ``"interval": "step"`` and ``"frequency": 1``.

    Raises:
        KeyError: if ``hparams.lr_scheduler`` is not a key of
            ``arg_to_scheduler``.
    """
    build_schedule = arg_to_scheduler[self.hparams.lr_scheduler]
    lr_schedule = build_schedule(
        self.opt,
        num_warmup_steps=self.hparams.warmup_steps,
        num_training_steps=self.total_steps,
    )
    # NOTE(review): presumably the scheduler-config dict format consumed by
    # PyTorch Lightning (step the scheduler on every optimizer step) - confirm.
    return {"scheduler": lr_schedule, "interval": "step", "frequency": 1}
|
||||||
|
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
"""Prepare optimizer and schedule (linear warmup and decay)"""
|
"""Prepare optimizer and schedule (linear warmup and decay)"""
|
||||||
model = self.model
|
model = self.model
|
||||||
@@ -114,10 +139,8 @@ class BaseTransformer(pl.LightningModule):
|
|||||||
optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)
|
optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)
|
||||||
self.opt = optimizer
|
self.opt = optimizer
|
||||||
|
|
||||||
scheduler = get_linear_schedule_with_warmup(
|
scheduler = self.get_lr_scheduler()
|
||||||
self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.total_steps
|
|
||||||
)
|
|
||||||
scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
|
|
||||||
return [optimizer], [scheduler]
|
return [optimizer], [scheduler]
|
||||||
|
|
||||||
def test_step(self, batch, batch_nb):
|
def test_step(self, batch, batch_nb):
|
||||||
@@ -203,6 +226,14 @@ class BaseTransformer(pl.LightningModule):
|
|||||||
"--attention_dropout", type=float, help="Attention dropout probability (Optional). Goes into model.config",
|
"--attention_dropout", type=float, help="Attention dropout probability (Optional). Goes into model.config",
|
||||||
)
|
)
|
||||||
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
|
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--lr_scheduler",
|
||||||
|
default="linear",
|
||||||
|
choices=arg_to_scheduler_choices,
|
||||||
|
metavar=arg_to_scheduler_metavar,
|
||||||
|
type=str,
|
||||||
|
help="Learning rate scheduler",
|
||||||
|
)
|
||||||
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
|
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
|
||||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
||||||
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
|
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
|
||||||
|
|||||||
@@ -8,15 +8,17 @@ from pathlib import Path
|
|||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
from pytest import param
|
from pytest import param
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
import lightning_base
|
||||||
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, MBartTokenizer
|
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, MBartTokenizer
|
||||||
from transformers.testing_utils import require_multigpu
|
from transformers.testing_utils import require_multigpu
|
||||||
|
|
||||||
from .distillation import distill_main, evaluate_checkpoint
|
from .distillation import distill_main, evaluate_checkpoint
|
||||||
from .finetune import main
|
from .finetune import SummarizationModule, main
|
||||||
from .pack_dataset import pack_data_dir
|
from .pack_dataset import pack_data_dir
|
||||||
from .run_eval import generate_summaries_or_translations, run_generate
|
from .run_eval import generate_summaries_or_translations, run_generate
|
||||||
from .utils import MBartDataset, Seq2SeqDataset, label_smoothed_nll_loss, lmap, load_json
|
from .utils import MBartDataset, Seq2SeqDataset, label_smoothed_nll_loss, lmap, load_json
|
||||||
@@ -60,6 +62,7 @@ CHEAP_ARGS = {
|
|||||||
"tokenizer_name": "facebook/bart-large",
|
"tokenizer_name": "facebook/bart-large",
|
||||||
"do_lower_case": False,
|
"do_lower_case": False,
|
||||||
"learning_rate": 0.3,
|
"learning_rate": 0.3,
|
||||||
|
"lr_scheduler": "linear",
|
||||||
"weight_decay": 0.0,
|
"weight_decay": 0.0,
|
||||||
"adam_epsilon": 1e-08,
|
"adam_epsilon": 1e-08,
|
||||||
"warmup_steps": 0,
|
"warmup_steps": 0,
|
||||||
@@ -326,6 +329,65 @@ def test_finetune_extra_model_args():
|
|||||||
assert str(excinfo.value) == f"model config doesn't have a `{unsupported_param}` attribute"
|
assert str(excinfo.value) == f"model config doesn't have a `{unsupported_param}` attribute"
|
||||||
|
|
||||||
|
|
||||||
|
def test_finetune_lr_shedulers(capsys):
    """Check the ``--lr_scheduler`` option end to end.

    Covers three cases:
      1. ``--help`` exits and its output lists the supported schedulers.
      2. An unsupported scheduler name makes argparse exit with an
         "invalid choice" error.
      3. A supported, non-default scheduler ("cosine") is accepted by
         ``main`` and ends up in the returned model's ``hparams``.

    Args:
        capsys: pytest fixture used to capture the parser's stdout/stderr.
    """
    task = "summarization"
    tmp_dir = make_test_data_dir()
    model_name = BART_TINY
    output_dir = tempfile.mkdtemp(prefix="output_1_")

    args_d: dict = CHEAP_ARGS.copy()
    args_d.update(
        data_dir=tmp_dir,
        model_name_or_path=model_name,
        output_dir=output_dir,
        tokenizer_name=None,
        train_batch_size=2,
        eval_batch_size=2,
        do_predict=False,
        task=task,
        src_lang="en_XX",
        tgt_lang="ro_RO",
        freeze_encoder=True,
        freeze_embeds=True,
    )

    # emulate finetune.py
    parser = argparse.ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())

    # --help test
    # argv is passed as an explicit list of strings; pytest.raises itself fails
    # the test if parse_args() returns instead of exiting.
    with pytest.raises(SystemExit):
        parser.parse_args(["--help"])
    captured = capsys.readouterr()
    expected = lightning_base.arg_to_scheduler_metavar
    assert expected in captured.out, "--help is expected to list the supported schedulers"

    # --lr_scheduler=non_existing_scheduler test
    unsupported_param = "non_existing_scheduler"
    with pytest.raises(SystemExit):
        parser.parse_args([f"--lr_scheduler={unsupported_param}"])
    captured = capsys.readouterr()
    expected = f"invalid choice: '{unsupported_param}'"
    assert expected in captured.err, f"should have bailed on invalid choice of scheduler {unsupported_param}"

    # --lr_scheduler=existing_scheduler test
    supported_param = "cosine"
    run_args = argparse.Namespace(**{**args_d, "lr_scheduler": supported_param})
    trained_model = main(run_args)
    assert trained_model.hparams.lr_scheduler == supported_param, f"lr_scheduler={supported_param} shouldn't fail"
|
||||||
|
|
||||||
|
|
||||||
def test_pack_dataset():
|
def test_pack_dataset():
|
||||||
tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
|
tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user