Experimental support for fairscale ShardedDDP (#9139)
* Experimental support for fairscale ShardedDDP * Add import error if fairscale not available * Address review comments * Fix seq2seq trainer
This commit is contained in:
@@ -20,6 +20,7 @@ from torch.utils.data import DistributedSampler, RandomSampler
|
|||||||
|
|
||||||
from transformers import PreTrainedModel, Trainer, logging
|
from transformers import PreTrainedModel, Trainer, logging
|
||||||
from transformers.file_utils import is_torch_tpu_available
|
from transformers.file_utils import is_torch_tpu_available
|
||||||
|
from transformers.integrations import is_fairscale_available
|
||||||
from transformers.models.fsmt.configuration_fsmt import FSMTConfig
|
from transformers.models.fsmt.configuration_fsmt import FSMTConfig
|
||||||
from transformers.optimization import (
|
from transformers.optimization import (
|
||||||
Adafactor,
|
Adafactor,
|
||||||
@@ -35,6 +36,10 @@ from transformers.trainer_pt_utils import get_tpu_sampler
|
|||||||
from transformers.training_args import ParallelMode
|
from transformers.training_args import ParallelMode
|
||||||
|
|
||||||
|
|
||||||
|
if is_fairscale_available():
|
||||||
|
from fairscale.optim import OSS
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
arg_to_scheduler = {
|
arg_to_scheduler = {
|
||||||
@@ -99,18 +104,25 @@ class Seq2SeqTrainer(Trainer):
|
|||||||
"weight_decay": 0.0,
|
"weight_decay": 0.0,
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
optimizer_cls = Adafactor if self.args.adafactor else AdamW
|
||||||
if self.args.adafactor:
|
if self.args.adafactor:
|
||||||
self.optimizer = Adafactor(
|
optimizer_cls = Adafactor
|
||||||
optimizer_grouped_parameters,
|
optimizer_kwargs = {"scale_parameter": False, "relative_step": False}
|
||||||
lr=self.args.learning_rate,
|
|
||||||
scale_parameter=False,
|
|
||||||
relative_step=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
self.optimizer = AdamW(
|
optimizer_cls = AdamW
|
||||||
optimizer_grouped_parameters, lr=self.args.learning_rate, eps=self.args.adam_epsilon
|
optimizer_kwargs = {
|
||||||
|
"betas": (self.args.adam_beta1, self.args.adam_beta2),
|
||||||
|
"eps": self.args.adam_epsilon,
|
||||||
|
}
|
||||||
|
optimizer_kwargs["lr"] = self.args.learning_rate
|
||||||
|
if self.sharded_dpp:
|
||||||
|
self.optimizer = OSS(
|
||||||
|
params=optimizer_grouped_parameters,
|
||||||
|
optim=optimizer_cls,
|
||||||
|
**optimizer_kwargs,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
|
||||||
|
|
||||||
if self.lr_scheduler is None:
|
if self.lr_scheduler is None:
|
||||||
self.lr_scheduler = self._get_lr_scheduler(num_training_steps)
|
self.lr_scheduler = self._get_lr_scheduler(num_training_steps)
|
||||||
|
|||||||
@@ -92,6 +92,13 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
_has_mlflow = False
|
_has_mlflow = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
import fairscale # noqa: F401
|
||||||
|
|
||||||
|
_has_fairscale = True
|
||||||
|
except ImportError:
|
||||||
|
_has_fairscale = False
|
||||||
|
|
||||||
# No transformer imports above this point
|
# No transformer imports above this point
|
||||||
|
|
||||||
from .file_utils import is_torch_tpu_available # noqa: E402
|
from .file_utils import is_torch_tpu_available # noqa: E402
|
||||||
@@ -128,6 +135,10 @@ def is_mlflow_available():
|
|||||||
return _has_mlflow
|
return _has_mlflow
|
||||||
|
|
||||||
|
|
||||||
|
def is_fairscale_available():
|
||||||
|
return _has_fairscale
|
||||||
|
|
||||||
|
|
||||||
def hp_params(trial):
|
def hp_params(trial):
|
||||||
if is_optuna_available():
|
if is_optuna_available():
|
||||||
if isinstance(trial, optuna.Trial):
|
if isinstance(trial, optuna.Trial):
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ from .integrations import ( # isort: split
|
|||||||
hp_params,
|
hp_params,
|
||||||
is_azureml_available,
|
is_azureml_available,
|
||||||
is_comet_available,
|
is_comet_available,
|
||||||
|
is_fairscale_available,
|
||||||
is_mlflow_available,
|
is_mlflow_available,
|
||||||
is_optuna_available,
|
is_optuna_available,
|
||||||
is_ray_available,
|
is_ray_available,
|
||||||
@@ -153,6 +154,11 @@ if is_azureml_available():
|
|||||||
|
|
||||||
DEFAULT_CALLBACKS.append(AzureMLCallback)
|
DEFAULT_CALLBACKS.append(AzureMLCallback)
|
||||||
|
|
||||||
|
if is_fairscale_available():
|
||||||
|
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
|
||||||
|
from fairscale.optim import OSS
|
||||||
|
from fairscale.optim.grad_scaler import ShardedGradScaler
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -285,6 +291,16 @@ class Trainer:
|
|||||||
if isinstance(eval_dataset, datasets.Dataset):
|
if isinstance(eval_dataset, datasets.Dataset):
|
||||||
self._remove_unused_columns(self.eval_dataset, description="evaluation")
|
self._remove_unused_columns(self.eval_dataset, description="evaluation")
|
||||||
|
|
||||||
|
# Setup Sharded DDP training
|
||||||
|
self.sharded_dpp = False
|
||||||
|
if args.sharded_ddp:
|
||||||
|
if args.local_rank == -1:
|
||||||
|
raise ValueError("Using sharded DDP only works in distributed training.")
|
||||||
|
elif not is_fairscale_available():
|
||||||
|
raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.")
|
||||||
|
else:
|
||||||
|
self.sharded_dpp = True
|
||||||
|
|
||||||
# Mixed precision setup
|
# Mixed precision setup
|
||||||
self.use_apex = False
|
self.use_apex = False
|
||||||
self.use_amp = False
|
self.use_amp = False
|
||||||
@@ -296,7 +312,7 @@ class Trainer:
|
|||||||
|
|
||||||
if backend == "amp":
|
if backend == "amp":
|
||||||
self.use_amp = True
|
self.use_amp = True
|
||||||
self.scaler = torch.cuda.amp.GradScaler()
|
self.scaler = ShardedGradScaler() if self.sharded_dpp else torch.cuda.amp.GradScaler()
|
||||||
else:
|
else:
|
||||||
if not is_apex_available():
|
if not is_apex_available():
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
@@ -491,12 +507,21 @@ class Trainer:
|
|||||||
"weight_decay": 0.0,
|
"weight_decay": 0.0,
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
self.optimizer = AdamW(
|
if self.sharded_dpp:
|
||||||
optimizer_grouped_parameters,
|
self.optimizer = OSS(
|
||||||
lr=self.args.learning_rate,
|
params=optimizer_grouped_parameters,
|
||||||
betas=(self.args.adam_beta1, self.args.adam_beta2),
|
optim=AdamW,
|
||||||
eps=self.args.adam_epsilon,
|
lr=self.args.learning_rate,
|
||||||
)
|
betas=(self.args.adam_beta1, self.args.adam_beta2),
|
||||||
|
eps=self.args.adam_epsilon,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.optimizer = AdamW(
|
||||||
|
optimizer_grouped_parameters,
|
||||||
|
lr=self.args.learning_rate,
|
||||||
|
betas=(self.args.adam_beta1, self.args.adam_beta2),
|
||||||
|
eps=self.args.adam_epsilon,
|
||||||
|
)
|
||||||
if self.lr_scheduler is None:
|
if self.lr_scheduler is None:
|
||||||
self.lr_scheduler = get_linear_schedule_with_warmup(
|
self.lr_scheduler = get_linear_schedule_with_warmup(
|
||||||
self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
|
self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
|
||||||
@@ -643,7 +668,9 @@ class Trainer:
|
|||||||
model = torch.nn.DataParallel(model)
|
model = torch.nn.DataParallel(model)
|
||||||
|
|
||||||
# Distributed training (should be after apex fp16 initialization)
|
# Distributed training (should be after apex fp16 initialization)
|
||||||
if self.args.local_rank != -1:
|
if self.sharded_dpp:
|
||||||
|
model = ShardedDDP(model, self.optimizer)
|
||||||
|
elif self.args.local_rank != -1:
|
||||||
model = torch.nn.parallel.DistributedDataParallel(
|
model = torch.nn.parallel.DistributedDataParallel(
|
||||||
model,
|
model,
|
||||||
device_ids=[self.args.local_rank],
|
device_ids=[self.args.local_rank],
|
||||||
@@ -654,8 +681,8 @@ class Trainer:
|
|||||||
else True
|
else True
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
# find_unused_parameters breaks checkpointing as per
|
# find_unused_parameters breaks checkpointing as per
|
||||||
# https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
|
# https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
|
||||||
|
|
||||||
# Train!
|
# Train!
|
||||||
if is_torch_tpu_available():
|
if is_torch_tpu_available():
|
||||||
@@ -895,6 +922,8 @@ class Trainer:
|
|||||||
self.save_model(output_dir)
|
self.save_model(output_dir)
|
||||||
|
|
||||||
# Save optimizer and scheduler
|
# Save optimizer and scheduler
|
||||||
|
if self.sharded_dpp:
|
||||||
|
self.optimizer.consolidate_state_dict()
|
||||||
if is_torch_tpu_available():
|
if is_torch_tpu_available():
|
||||||
xm.rendezvous("saving_optimizer_states")
|
xm.rendezvous("saving_optimizer_states")
|
||||||
xm.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
xm.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
||||||
|
|||||||
@@ -215,6 +215,9 @@ class TrainingArguments:
|
|||||||
The backend to use for mixed precision training. Must be one of :obj:`"auto"`, :obj:`"amp"` or
|
The backend to use for mixed precision training. Must be one of :obj:`"auto"`, :obj:`"amp"` or
|
||||||
:obj:`"apex"`. :obj:`"auto"` will use AMP or APEX depending on the PyTorch version detected, while the
|
:obj:`"apex"`. :obj:`"auto"` will use AMP or APEX depending on the PyTorch version detected, while the
|
||||||
other choices will force the requested backend.
|
other choices will force the requested backend.
|
||||||
|
sharded_ddp (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Use Sharded DDP training from `FairScale <https://github.com/facebookresearch/fairscale>`__ (in distributed
|
||||||
|
training only). This is an experimental feature.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
output_dir: str = field(
|
output_dir: str = field(
|
||||||
@@ -386,6 +389,10 @@ class TrainingArguments:
|
|||||||
default="auto",
|
default="auto",
|
||||||
metadata={"help": "The backend to be used for mixed precision.", "choices": ["auto", "amp", "apex"]},
|
metadata={"help": "The backend to be used for mixed precision.", "choices": ["auto", "amp", "apex"]},
|
||||||
)
|
)
|
||||||
|
sharded_ddp: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "Whether or not to use sharded DDP training (in distributed training only)."},
|
||||||
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.disable_tqdm is None:
|
if self.disable_tqdm is None:
|
||||||
|
|||||||
Reference in New Issue
Block a user