From fb78a90d6a28ab8b730bea8c99be91b75f7e041e Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Thu, 27 Aug 2020 22:19:46 -0400 Subject: [PATCH] PL: --adafactor option (#6776) --- examples/lightning_base.py | 12 +++++++++++- examples/seq2seq/test_seq2seq_examples.py | 1 + 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/examples/lightning_base.py b/examples/lightning_base.py index d23757a9bc..4c8d3649f9 100644 --- a/examples/lightning_base.py +++ b/examples/lightning_base.py @@ -22,6 +22,7 @@ from transformers import ( PreTrainedTokenizer, ) from transformers.optimization import ( + Adafactor, get_cosine_schedule_with_warmup, get_cosine_with_hard_restarts_schedule_with_warmup, get_linear_schedule_with_warmup, @@ -137,7 +138,15 @@ class BaseTransformer(pl.LightningModule): "weight_decay": 0.0, }, ] - optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon) + if self.hparams.adafactor: + optimizer = Adafactor( + optimizer_grouped_parameters, lr=self.hparams.learning_rate, scale_parameter=False, relative_step=False + ) + + else: + optimizer = AdamW( + optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon + ) self.opt = optimizer scheduler = self.get_lr_scheduler() @@ -251,6 +260,7 @@ class BaseTransformer(pl.LightningModule): parser.add_argument("--num_train_epochs", dest="max_epochs", default=3, type=int) parser.add_argument("--train_batch_size", default=32, type=int) parser.add_argument("--eval_batch_size", default=32, type=int) + parser.add_argument("--adafactor", action="store_true") class LoggingCallback(pl.Callback): diff --git a/examples/seq2seq/test_seq2seq_examples.py b/examples/seq2seq/test_seq2seq_examples.py index 2f397c7adc..f853557f18 100644 --- a/examples/seq2seq/test_seq2seq_examples.py +++ b/examples/seq2seq/test_seq2seq_examples.py @@ -30,6 +30,7 @@ logger = logging.getLogger() CUDA_AVAILABLE = torch.cuda.is_available() CHEAP_ARGS = { "label_smoothing": 0.2, + "adafactor": True, "early_stopping_patience": 2, "logger_name": "default", "length_penalty": 0.5,