PL: --adafactor option (#6776)

This commit is contained in:
Sam Shleifer
2020-08-27 22:19:46 -04:00
committed by GitHub
parent 92ac2fa7d1
commit fb78a90d6a
2 changed files with 12 additions and 1 deletions

View File

@@ -22,6 +22,7 @@ from transformers import (
PreTrainedTokenizer,
)
from transformers.optimization import (
Adafactor,
get_cosine_schedule_with_warmup,
get_cosine_with_hard_restarts_schedule_with_warmup,
get_linear_schedule_with_warmup,
@@ -137,7 +138,15 @@ class BaseTransformer(pl.LightningModule):
"weight_decay": 0.0,
},
]
optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)
if self.hparams.adafactor:
optimizer = Adafactor(
optimizer_grouped_parameters, lr=self.hparams.learning_rate, scale_parameter=False, relative_step=False
)
else:
optimizer = AdamW(
optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon
)
self.opt = optimizer
scheduler = self.get_lr_scheduler()
@@ -251,6 +260,7 @@ class BaseTransformer(pl.LightningModule):
parser.add_argument("--num_train_epochs", dest="max_epochs", default=3, type=int)
parser.add_argument("--train_batch_size", default=32, type=int)
parser.add_argument("--eval_batch_size", default=32, type=int)
parser.add_argument("--adafactor", action="store_true")
class LoggingCallback(pl.Callback):

View File

@@ -30,6 +30,7 @@ logger = logging.getLogger()
CUDA_AVAILABLE = torch.cuda.is_available()
CHEAP_ARGS = {
"label_smoothing": 0.2,
"adafactor": True,
"early_stopping_patience": 2,
"logger_name": "default",
"length_penalty": 0.5,