Seq2seq trainer (#9241)

* Add label smoothing in Trainer

* Add options for scheduler and Adafactor in Trainer

* Put Seq2SeqTrainer in the main lib

* Apply suggestions from code review

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Address review comments and adapt scripts

* Documentation

* Move test not using script to tests folder

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Sylvain Gugger
2020-12-22 11:33:44 -05:00
committed by GitHub
parent 1fc7119181
commit 490b39e614
20 changed files with 655 additions and 166 deletions

View File

@@ -57,7 +57,7 @@ from .data.data_collator import DataCollator, DataCollatorWithPadding, default_d
from .file_utils import WEIGHTS_NAME, is_apex_available, is_datasets_available, is_in_notebook, is_torch_tpu_available
from .modeling_utils import PreTrainedModel
from .models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
from .optimization import AdamW, get_linear_schedule_with_warmup
from .optimization import Adafactor, AdamW, get_scheduler
from .tokenization_utils_base import PreTrainedTokenizerBase
from .trainer_callback import (
CallbackHandler,
@@ -70,6 +70,7 @@ from .trainer_callback import (
)
from .trainer_pt_utils import (
DistributedTensorGatherer,
LabelSmoother,
SequentialDistributedSampler,
distributed_broadcast_scalars,
distributed_concat,
@@ -320,6 +321,12 @@ class Trainer:
)
self.use_apex = True
# Label smoothing
if self.args.label_smoothing_factor != 0:
self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor)
else:
self.label_smoother = None
self.state = TrainerState()
self.control = TrainerControl()
# Internal variable for total_flos used to count as tensors (for distributed + TPU), will be sent in the
@@ -507,24 +514,32 @@ class Trainer:
"weight_decay": 0.0,
},
]
optimizer_cls = Adafactor if self.args.adafactor else AdamW
if self.args.adafactor:
optimizer_cls = Adafactor
optimizer_kwargs = {"scale_parameter": False, "relative_step": False}
else:
optimizer_cls = AdamW
optimizer_kwargs = {
"betas": (self.args.adam_beta1, self.args.adam_beta2),
"eps": self.args.adam_epsilon,
}
optimizer_kwargs["lr"] = self.args.learning_rate
if self.sharded_dpp:
self.optimizer = OSS(
params=optimizer_grouped_parameters,
optim=AdamW,
lr=self.args.learning_rate,
betas=(self.args.adam_beta1, self.args.adam_beta2),
eps=self.args.adam_epsilon,
optim=optimizer_cls,
**optimizer_kwargs,
)
else:
self.optimizer = AdamW(
optimizer_grouped_parameters,
lr=self.args.learning_rate,
betas=(self.args.adam_beta1, self.args.adam_beta2),
eps=self.args.adam_epsilon,
)
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
if self.lr_scheduler is None:
self.lr_scheduler = get_linear_schedule_with_warmup(
self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
self.lr_scheduler = get_scheduler(
self.args.lr_scheduler_type,
self.optimizer,
num_warmup_steps=self.args.warmup_steps,
num_training_steps=num_training_steps,
)
def num_examples(self, dataloader: DataLoader) -> int:
@@ -1168,8 +1183,12 @@ class Trainer:
# TODO: this needs to be fixed and made cleaner later.
if self.args.past_index >= 0:
self._past = outputs[self.args.past_index]
# We don't use .loss here since the model may return tuples instead of ModelOutput.
return outputs["loss"] if isinstance(outputs, dict) else outputs[0]
if self.label_smoother is not None and "labels" in inputs:
return self.label_smoother(outputs, inputs["labels"])
else:
# We don't use .loss here since the model may return tuples instead of ModelOutput.
return outputs["loss"] if isinstance(outputs, dict) else outputs[0]
def is_local_process_zero(self) -> bool:
"""
@@ -1556,11 +1575,13 @@ class Trainer:
else:
outputs = model(**inputs)
if has_labels:
if self.label_smoother is not None and "labels" in inputs:
loss = self.label_smoother(outputs, inputs["labels"]).mean().detach()
else:
loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach()
if isinstance(outputs, dict):
loss = outputs["loss"].mean().detach()
logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
else:
loss = outputs[0].mean().detach()
logits = outputs[1:]
else:
loss = None