[s2s] Adafactor support for builtin trainer (#7522)

This commit is contained in:
Sam Shleifer
2020-10-01 17:27:45 -04:00
committed by GitHub
parent d3a9601a11
commit de4d7b004a
3 changed files with 40 additions and 0 deletions

View File

@@ -7,6 +7,7 @@ from torch.utils.data import DistributedSampler, RandomSampler
from transformers import Trainer
from transformers.file_utils import is_torch_tpu_available
from transformers.optimization import Adafactor, AdamW, get_linear_schedule_with_warmup
from transformers.trainer import get_tpu_sampler
@@ -28,6 +29,43 @@ class Seq2SeqTrainer(Trainer):
self.pad_token_id = self.config.pad_token_id
self.vocab_size = self.config.vocab_size
def create_optimizer_and_scheduler(self, num_training_steps: int):
"""
Setup the optimizer and the learning rate scheduler.
We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
Trainer's init through :obj:`optimizers`, or subclass and override this method in a subclass.
"""
if self.optimizer is None:
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
"weight_decay": self.args.weight_decay,
},
{
"params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
"weight_decay": 0.0,
},
]
if self.args.adafactor:
self.optimizer = Adafactor(
optimizer_grouped_parameters,
lr=self.args.learning_rate,
scale_parameter=False,
relative_step=False,
)
else:
self.optimizer = AdamW(
optimizer_grouped_parameters, lr=self.args.learning_rate, eps=self.args.adam_epsilon
)
if self.lr_scheduler is None:
self.lr_scheduler = get_linear_schedule_with_warmup(
self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
)
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
return None