remove SharedDDP as it is deprecated (#25702)

* remove SharedDDP as it was drepracated

* apply review suggestion

* make style

* Oops,forgot to remove the compute_loss context manager in Seq2SeqTrainer.

* remove the unnecessary conditional statement

* keep the logic of IPEX

* clean code

* mix precision setup & make fixup

---------

Co-authored-by: statelesshz <jihuazhong1@huawei.com>
This commit is contained in:
statelesshz
2023-10-06 22:03:11 +08:00
committed by GitHub
parent e840aa67e8
commit 27597fea07
11 changed files with 39 additions and 299 deletions

View File

@@ -19,7 +19,6 @@ from torch import nn
from torch.utils.data import DistributedSampler, RandomSampler
from transformers import PreTrainedModel, Trainer, logging
from transformers.integrations import is_fairscale_available
from transformers.models.fsmt.configuration_fsmt import FSMTConfig
from transformers.optimization import (
Adafactor,
@@ -36,10 +35,6 @@ from transformers.training_args import ParallelMode
from transformers.utils import is_torch_tpu_available
if is_fairscale_available():
from fairscale.optim import OSS
logger = logging.get_logger(__name__)
arg_to_scheduler = {
@@ -118,14 +113,7 @@ class Seq2SeqTrainer(Trainer):
"eps": self.args.adam_epsilon,
}
optimizer_kwargs["lr"] = self.args.learning_rate
if self.sharded_ddp:
self.optimizer = OSS(
params=optimizer_grouped_parameters,
optim=optimizer_cls,
**optimizer_kwargs,
)
else:
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
if self.lr_scheduler is None:
self.lr_scheduler = self._get_lr_scheduler(num_training_steps)