Experimental support for fairscale ShardedDDP (#9139)

* Experimental stupport for fairscale ShardedDDP

* Add import error if fairscale not available

* Address review comments

* Fix seq2seq trainer
This commit is contained in:
Sylvain Gugger
2020-12-16 13:47:48 -05:00
committed by GitHub
parent 1c1a2ffbff
commit 9a67185344
4 changed files with 78 additions and 19 deletions

View File

@@ -20,6 +20,7 @@ from torch.utils.data import DistributedSampler, RandomSampler
from transformers import PreTrainedModel, Trainer, logging
from transformers.file_utils import is_torch_tpu_available
from transformers.integrations import is_fairscale_available
from transformers.models.fsmt.configuration_fsmt import FSMTConfig
from transformers.optimization import (
Adafactor,
@@ -35,6 +36,10 @@ from transformers.trainer_pt_utils import get_tpu_sampler
from transformers.training_args import ParallelMode
if is_fairscale_available():
from fairscale.optim import OSS
logger = logging.get_logger(__name__)
arg_to_scheduler = {
@@ -99,18 +104,25 @@ class Seq2SeqTrainer(Trainer):
"weight_decay": 0.0,
},
]
optimizer_cls = Adafactor if self.args.adafactor else AdamW
if self.args.adafactor:
self.optimizer = Adafactor(
optimizer_grouped_parameters,
lr=self.args.learning_rate,
scale_parameter=False,
relative_step=False,
)
optimizer_cls = Adafactor
optimizer_kwargs = {"scale_parameter": False, "relative_step": False}
else:
self.optimizer = AdamW(
optimizer_grouped_parameters, lr=self.args.learning_rate, eps=self.args.adam_epsilon
optimizer_cls = AdamW
optimizer_kwargs = {
"betas": (self.args.adam_beta1, self.args.adam_beta2),
"eps": self.args.adam_epsilon,
}
optimizer_kwargs["lr"] = self.args.learning_rate
if self.sharded_dpp:
self.optimizer = OSS(
params=optimizer_grouped_parameters,
optim=optimizer_cls,
**optimizer_kwargs,
)
else:
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
if self.lr_scheduler is None:
self.lr_scheduler = self._get_lr_scheduler(num_training_steps)