PL (PyTorch Lightning): add --adafactor option (#6776)
This commit is contained in:
@@ -22,6 +22,7 @@ from transformers import (
|
|||||||
PreTrainedTokenizer,
|
PreTrainedTokenizer,
|
||||||
)
|
)
|
||||||
from transformers.optimization import (
|
from transformers.optimization import (
|
||||||
|
Adafactor,
|
||||||
get_cosine_schedule_with_warmup,
|
get_cosine_schedule_with_warmup,
|
||||||
get_cosine_with_hard_restarts_schedule_with_warmup,
|
get_cosine_with_hard_restarts_schedule_with_warmup,
|
||||||
get_linear_schedule_with_warmup,
|
get_linear_schedule_with_warmup,
|
||||||
@@ -137,7 +138,15 @@ class BaseTransformer(pl.LightningModule):
|
|||||||
"weight_decay": 0.0,
|
"weight_decay": 0.0,
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)
|
if self.hparams.adafactor:
|
||||||
|
optimizer = Adafactor(
|
||||||
|
optimizer_grouped_parameters, lr=self.hparams.learning_rate, scale_parameter=False, relative_step=False
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
optimizer = AdamW(
|
||||||
|
optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon
|
||||||
|
)
|
||||||
self.opt = optimizer
|
self.opt = optimizer
|
||||||
|
|
||||||
scheduler = self.get_lr_scheduler()
|
scheduler = self.get_lr_scheduler()
|
||||||
@@ -251,6 +260,7 @@ class BaseTransformer(pl.LightningModule):
|
|||||||
parser.add_argument("--num_train_epochs", dest="max_epochs", default=3, type=int)
|
parser.add_argument("--num_train_epochs", dest="max_epochs", default=3, type=int)
|
||||||
parser.add_argument("--train_batch_size", default=32, type=int)
|
parser.add_argument("--train_batch_size", default=32, type=int)
|
||||||
parser.add_argument("--eval_batch_size", default=32, type=int)
|
parser.add_argument("--eval_batch_size", default=32, type=int)
|
||||||
|
parser.add_argument("--adafactor", action="store_true")
|
||||||
|
|
||||||
|
|
||||||
class LoggingCallback(pl.Callback):
|
class LoggingCallback(pl.Callback):
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ logger = logging.getLogger()
|
|||||||
CUDA_AVAILABLE = torch.cuda.is_available()
|
CUDA_AVAILABLE = torch.cuda.is_available()
|
||||||
CHEAP_ARGS = {
|
CHEAP_ARGS = {
|
||||||
"label_smoothing": 0.2,
|
"label_smoothing": 0.2,
|
||||||
|
"adafactor": True,
|
||||||
"early_stopping_patience": 2,
|
"early_stopping_patience": 2,
|
||||||
"logger_name": "default",
|
"logger_name": "default",
|
||||||
"length_penalty": 0.5,
|
"length_penalty": 0.5,
|
||||||
|
|||||||
Reference in New Issue
Block a user