Experimental support for fairscale ShardedDDP (#9139)
* Experimental stupport for fairscale ShardedDDP * Add import error if fairscale not available * Address review comments * Fix seq2seq trainer
This commit is contained in:
@@ -33,6 +33,7 @@ from .integrations import ( # isort: split
|
||||
hp_params,
|
||||
is_azureml_available,
|
||||
is_comet_available,
|
||||
is_fairscale_available,
|
||||
is_mlflow_available,
|
||||
is_optuna_available,
|
||||
is_ray_available,
|
||||
@@ -153,6 +154,11 @@ if is_azureml_available():
|
||||
|
||||
DEFAULT_CALLBACKS.append(AzureMLCallback)
|
||||
|
||||
if is_fairscale_available():
|
||||
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
|
||||
from fairscale.optim import OSS
|
||||
from fairscale.optim.grad_scaler import ShardedGradScaler
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@@ -285,6 +291,16 @@ class Trainer:
|
||||
if isinstance(eval_dataset, datasets.Dataset):
|
||||
self._remove_unused_columns(self.eval_dataset, description="evaluation")
|
||||
|
||||
# Setup Sharded DDP training
|
||||
self.sharded_dpp = False
|
||||
if args.sharded_ddp:
|
||||
if args.local_rank == -1:
|
||||
raise ValueError("Using sharded DDP only works in distributed training.")
|
||||
elif not is_fairscale_available():
|
||||
raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.")
|
||||
else:
|
||||
self.sharded_dpp = True
|
||||
|
||||
# Mixed precision setup
|
||||
self.use_apex = False
|
||||
self.use_amp = False
|
||||
@@ -296,7 +312,7 @@ class Trainer:
|
||||
|
||||
if backend == "amp":
|
||||
self.use_amp = True
|
||||
self.scaler = torch.cuda.amp.GradScaler()
|
||||
self.scaler = ShardedGradScaler() if self.sharded_dpp else torch.cuda.amp.GradScaler()
|
||||
else:
|
||||
if not is_apex_available():
|
||||
raise ImportError(
|
||||
@@ -491,12 +507,21 @@ class Trainer:
|
||||
"weight_decay": 0.0,
|
||||
},
|
||||
]
|
||||
self.optimizer = AdamW(
|
||||
optimizer_grouped_parameters,
|
||||
lr=self.args.learning_rate,
|
||||
betas=(self.args.adam_beta1, self.args.adam_beta2),
|
||||
eps=self.args.adam_epsilon,
|
||||
)
|
||||
if self.sharded_dpp:
|
||||
self.optimizer = OSS(
|
||||
params=optimizer_grouped_parameters,
|
||||
optim=AdamW,
|
||||
lr=self.args.learning_rate,
|
||||
betas=(self.args.adam_beta1, self.args.adam_beta2),
|
||||
eps=self.args.adam_epsilon,
|
||||
)
|
||||
else:
|
||||
self.optimizer = AdamW(
|
||||
optimizer_grouped_parameters,
|
||||
lr=self.args.learning_rate,
|
||||
betas=(self.args.adam_beta1, self.args.adam_beta2),
|
||||
eps=self.args.adam_epsilon,
|
||||
)
|
||||
if self.lr_scheduler is None:
|
||||
self.lr_scheduler = get_linear_schedule_with_warmup(
|
||||
self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
|
||||
@@ -643,7 +668,9 @@ class Trainer:
|
||||
model = torch.nn.DataParallel(model)
|
||||
|
||||
# Distributed training (should be after apex fp16 initialization)
|
||||
if self.args.local_rank != -1:
|
||||
if self.sharded_dpp:
|
||||
model = ShardedDDP(model, self.optimizer)
|
||||
elif self.args.local_rank != -1:
|
||||
model = torch.nn.parallel.DistributedDataParallel(
|
||||
model,
|
||||
device_ids=[self.args.local_rank],
|
||||
@@ -654,8 +681,8 @@ class Trainer:
|
||||
else True
|
||||
),
|
||||
)
|
||||
# find_unused_parameters breaks checkpointing as per
|
||||
# https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
|
||||
# find_unused_parameters breaks checkpointing as per
|
||||
# https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
|
||||
|
||||
# Train!
|
||||
if is_torch_tpu_available():
|
||||
@@ -895,6 +922,8 @@ class Trainer:
|
||||
self.save_model(output_dir)
|
||||
|
||||
# Save optimizer and scheduler
|
||||
if self.sharded_dpp:
|
||||
self.optimizer.consolidate_state_dict()
|
||||
if is_torch_tpu_available():
|
||||
xm.rendezvous("saving_optimizer_states")
|
||||
xm.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
||||
|
||||
Reference in New Issue
Block a user