Experimental support for fairscale ShardedDDP (#9139)
* Experimental support for fairscale ShardedDDP * Add import error if fairscale not available * Address review comments * Fix seq2seq trainer
This commit is contained in:
@@ -20,6 +20,7 @@ from torch.utils.data import DistributedSampler, RandomSampler
|
||||
|
||||
from transformers import PreTrainedModel, Trainer, logging
|
||||
from transformers.file_utils import is_torch_tpu_available
|
||||
from transformers.integrations import is_fairscale_available
|
||||
from transformers.models.fsmt.configuration_fsmt import FSMTConfig
|
||||
from transformers.optimization import (
|
||||
Adafactor,
|
||||
@@ -35,6 +36,10 @@ from transformers.trainer_pt_utils import get_tpu_sampler
|
||||
from transformers.training_args import ParallelMode
|
||||
|
||||
|
||||
if is_fairscale_available():
|
||||
from fairscale.optim import OSS
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
arg_to_scheduler = {
|
||||
@@ -99,18 +104,25 @@ class Seq2SeqTrainer(Trainer):
|
||||
"weight_decay": 0.0,
|
||||
},
|
||||
]
|
||||
optimizer_cls = Adafactor if self.args.adafactor else AdamW
|
||||
if self.args.adafactor:
|
||||
self.optimizer = Adafactor(
|
||||
optimizer_grouped_parameters,
|
||||
lr=self.args.learning_rate,
|
||||
scale_parameter=False,
|
||||
relative_step=False,
|
||||
)
|
||||
|
||||
optimizer_cls = Adafactor
|
||||
optimizer_kwargs = {"scale_parameter": False, "relative_step": False}
|
||||
else:
|
||||
self.optimizer = AdamW(
|
||||
optimizer_grouped_parameters, lr=self.args.learning_rate, eps=self.args.adam_epsilon
|
||||
optimizer_cls = AdamW
|
||||
optimizer_kwargs = {
|
||||
"betas": (self.args.adam_beta1, self.args.adam_beta2),
|
||||
"eps": self.args.adam_epsilon,
|
||||
}
|
||||
optimizer_kwargs["lr"] = self.args.learning_rate
|
||||
if self.sharded_dpp:
|
||||
self.optimizer = OSS(
|
||||
params=optimizer_grouped_parameters,
|
||||
optim=optimizer_cls,
|
||||
**optimizer_kwargs,
|
||||
)
|
||||
else:
|
||||
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
|
||||
|
||||
if self.lr_scheduler is None:
|
||||
self.lr_scheduler = self._get_lr_scheduler(num_training_steps)
|
||||
|
||||
Reference in New Issue
Block a user