[s2s] Adafactor support for builtin trainer (#7522)
This commit is contained in:
@@ -52,6 +52,7 @@ class Seq2SeqTrainingArguments(TrainingArguments):
|
|||||||
predict_with_generate: bool = field(
|
predict_with_generate: bool = field(
|
||||||
default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
|
default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
|
||||||
)
|
)
|
||||||
|
adafactor: bool = field(default=False, metadata={"help": "whether to use adafactor"})
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from torch.utils.data import DistributedSampler, RandomSampler
|
|||||||
|
|
||||||
from transformers import Trainer
|
from transformers import Trainer
|
||||||
from transformers.file_utils import is_torch_tpu_available
|
from transformers.file_utils import is_torch_tpu_available
|
||||||
|
from transformers.optimization import Adafactor, AdamW, get_linear_schedule_with_warmup
|
||||||
from transformers.trainer import get_tpu_sampler
|
from transformers.trainer import get_tpu_sampler
|
||||||
|
|
||||||
|
|
||||||
@@ -28,6 +29,43 @@ class Seq2SeqTrainer(Trainer):
|
|||||||
self.pad_token_id = self.config.pad_token_id
|
self.pad_token_id = self.config.pad_token_id
|
||||||
self.vocab_size = self.config.vocab_size
|
self.vocab_size = self.config.vocab_size
|
||||||
|
|
||||||
|
def create_optimizer_and_scheduler(self, num_training_steps: int):
    """
    Build the optimizer and learning-rate scheduler if they have not been set yet.

    An optimizer already present on ``self.optimizer`` or a scheduler already present on
    ``self.lr_scheduler`` is left untouched, so each is created only when it is ``None``.

    The optimizer is Adafactor when ``self.args.adafactor`` is truthy and AdamW otherwise. Parameters whose name
    contains ``"bias"`` or ``"LayerNorm.weight"`` are placed in a group with zero weight decay; all other
    parameters use ``self.args.weight_decay``. The scheduler is created by ``get_linear_schedule_with_warmup``.

    Args:
        num_training_steps: Total number of training steps, forwarded to the scheduler.
    """
    if self.optimizer is None:
        decay_exempt_markers = ["bias", "LayerNorm.weight"]

        # Walk the parameters once and sort each into the decayed or the exempt bucket,
        # preserving the order in which the model yields them.
        decayed_params = []
        exempt_params = []
        for param_name, param in self.model.named_parameters():
            is_exempt = any(marker in param_name for marker in decay_exempt_markers)
            if is_exempt:
                exempt_params.append(param)
            else:
                decayed_params.append(param)

        param_groups = [
            {"params": decayed_params, "weight_decay": self.args.weight_decay},
            {"params": exempt_params, "weight_decay": 0.0},
        ]

        if self.args.adafactor:
            # Adafactor is driven by the explicit learning rate from the arguments:
            # both parameter scaling and the relative-step mode are switched off.
            self.optimizer = Adafactor(
                param_groups,
                lr=self.args.learning_rate,
                scale_parameter=False,
                relative_step=False,
            )
        else:
            # NOTE(review): only lr and eps are forwarded to AdamW here; if the training
            # arguments also carry beta settings they are not applied -- confirm intended.
            self.optimizer = AdamW(
                param_groups,
                lr=self.args.learning_rate,
                eps=self.args.adam_epsilon,
            )

    if self.lr_scheduler is None:
        self.lr_scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=self.args.warmup_steps,
            num_training_steps=num_training_steps,
        )
|
||||||
|
|
||||||
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
|
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
|
||||||
if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
|
if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -91,6 +91,7 @@ def run_trainer(eval_steps: int, max_len: str, model_name: str, num_train_epochs
|
|||||||
"0.1",
|
"0.1",
|
||||||
# "--eval_beams",
|
# "--eval_beams",
|
||||||
# "2",
|
# "2",
|
||||||
|
"--adafactor",
|
||||||
"--task",
|
"--task",
|
||||||
"translation",
|
"translation",
|
||||||
"--tgt_lang",
|
"--tgt_lang",
|
||||||
|
|||||||
Reference in New Issue
Block a user