PL: --adafactor option (#6776)
This commit is contained in:
@@ -22,6 +22,7 @@ from transformers import (
|
||||
PreTrainedTokenizer,
|
||||
)
|
||||
from transformers.optimization import (
|
||||
Adafactor,
|
||||
get_cosine_schedule_with_warmup,
|
||||
get_cosine_with_hard_restarts_schedule_with_warmup,
|
||||
get_linear_schedule_with_warmup,
|
||||
@@ -137,7 +138,15 @@ class BaseTransformer(pl.LightningModule):
|
||||
"weight_decay": 0.0,
|
||||
},
|
||||
]
|
||||
optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)
|
||||
if self.hparams.adafactor:
|
||||
optimizer = Adafactor(
|
||||
optimizer_grouped_parameters, lr=self.hparams.learning_rate, scale_parameter=False, relative_step=False
|
||||
)
|
||||
|
||||
else:
|
||||
optimizer = AdamW(
|
||||
optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon
|
||||
)
|
||||
self.opt = optimizer
|
||||
|
||||
scheduler = self.get_lr_scheduler()
|
||||
@@ -251,6 +260,7 @@ class BaseTransformer(pl.LightningModule):
|
||||
parser.add_argument("--num_train_epochs", dest="max_epochs", default=3, type=int)
|
||||
parser.add_argument("--train_batch_size", default=32, type=int)
|
||||
parser.add_argument("--eval_batch_size", default=32, type=int)
|
||||
parser.add_argument("--adafactor", action="store_true")
|
||||
|
||||
|
||||
class LoggingCallback(pl.Callback):
|
||||
|
||||
@@ -30,6 +30,7 @@ logger = logging.getLogger()
|
||||
CUDA_AVAILABLE = torch.cuda.is_available()
|
||||
CHEAP_ARGS = {
|
||||
"label_smoothing": 0.2,
|
||||
"adafactor": True,
|
||||
"early_stopping_patience": 2,
|
||||
"logger_name": "default",
|
||||
"length_penalty": 0.5,
|
||||
|
||||
Reference in New Issue
Block a user