Add support for ZeRO-2/3 and ZeRO-offload in fairscale (#10354)
* Ass support for ZeRO-2/3 and ZeRO-offload in fairscale * Quality * Rework from review comments * Add doc * Apply suggestions from code review Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> * Address review comments Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
This commit is contained in:
@@ -241,6 +241,8 @@ provides support for the following features from `the ZeRO paper <https://arxiv.
|
|||||||
|
|
||||||
1. Optimizer State Sharding
|
1. Optimizer State Sharding
|
||||||
2. Gradient Sharding
|
2. Gradient Sharding
|
||||||
|
3. Model Parameters Sharding (new and very experimental)
|
||||||
|
4. CPU offload (new and very experimental)
|
||||||
|
|
||||||
You will need at least two GPUs to use this feature.
|
You will need at least two GPUs to use this feature.
|
||||||
|
|
||||||
@@ -255,8 +257,9 @@ To deploy this feature:
|
|||||||
or find more details on `the FairScale's GitHub page
|
or find more details on `the FairScale's GitHub page
|
||||||
<https://github.com/facebookresearch/fairscale/#installation>`__.
|
<https://github.com/facebookresearch/fairscale/#installation>`__.
|
||||||
|
|
||||||
2. Add ``--sharded_ddp`` to the command line arguments, and make sure you have added the distributed launcher ``-m
|
2. To use the first version of Sharded data-parallelism, add ``--sharded_ddp simple`` to the command line arguments,
|
||||||
torch.distributed.launch --nproc_per_node=NUMBER_OF_GPUS_YOU_HAVE`` if you haven't been using it already.
|
and make sure you have added the distributed launcher ``-m torch.distributed.launch
|
||||||
|
--nproc_per_node=NUMBER_OF_GPUS_YOU_HAVE`` if you haven't been using it already.
|
||||||
|
|
||||||
For example here is how you could use it for ``run_seq2seq.py`` with 2 GPUs:
|
For example here is how you could use it for ``run_seq2seq.py`` with 2 GPUs:
|
||||||
|
|
||||||
@@ -268,17 +271,55 @@ For example here is how you could use it for ``run_seq2seq.py`` with 2 GPUs:
|
|||||||
--do_train --max_train_samples 500 --num_train_epochs 1 \
|
--do_train --max_train_samples 500 --num_train_epochs 1 \
|
||||||
--dataset_name wmt16 --dataset_config "ro-en" \
|
--dataset_name wmt16 --dataset_config "ro-en" \
|
||||||
--task translation_en_to_ro --source_prefix "translate English to Romanian: " \
|
--task translation_en_to_ro --source_prefix "translate English to Romanian: " \
|
||||||
--fp16 --sharded_ddp
|
--fp16 --sharded_ddp simple
|
||||||
|
|
||||||
Notes:
|
Notes:
|
||||||
|
|
||||||
- This feature requires distributed training (so multiple GPUs).
|
- This feature requires distributed training (so multiple GPUs).
|
||||||
- It is not implemented for TPUs.
|
- It is not implemented for TPUs.
|
||||||
- It works with ``--fp16`` too, to make things even faster.
|
- It works with ``--fp16`` too, to make things even faster.
|
||||||
- One of the main benefits of enabling ``--sharded_ddp`` is that it uses a lot less GPU memory, so you should be able
|
- One of the main benefits of enabling ``--sharded_ddp simple`` is that it uses a lot less GPU memory, so you should be
|
||||||
to use significantly larger batch sizes using the same hardware (e.g. 3x and even bigger) which should lead to
|
able to use significantly larger batch sizes using the same hardware (e.g. 3x and even bigger) which should lead to
|
||||||
significantly shorter training time.
|
significantly shorter training time.
|
||||||
|
|
||||||
|
3. To use the second version of Sharded data-parallelism, add ``--sharded_ddp zero_dp_2`` or ``--sharded_ddp zero_dp_3`
|
||||||
|
to the command line arguments, and make sure you have added the distributed launcher ``-m torch.distributed.launch
|
||||||
|
--nproc_per_node=NUMBER_OF_GPUS_YOU_HAVE`` if you haven't been using it already.
|
||||||
|
|
||||||
|
For example here is how you could use it for ``run_seq2seq.py`` with 2 GPUs:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
python -m torch.distributed.launch --nproc_per_node=2 examples/seq2seq/run_seq2seq.py \
|
||||||
|
--model_name_or_path t5-small --per_device_train_batch_size 1 \
|
||||||
|
--output_dir output_dir --overwrite_output_dir \
|
||||||
|
--do_train --max_train_samples 500 --num_train_epochs 1 \
|
||||||
|
--dataset_name wmt16 --dataset_config "ro-en" \
|
||||||
|
--task translation_en_to_ro --source_prefix "translate English to Romanian: " \
|
||||||
|
--fp16 --sharded_ddp zero_dp_2
|
||||||
|
|
||||||
|
:obj:`zero_dp_2` is an optimized version of the simple wrapper, while :obj:`zero_dp_3` fully shards model weights,
|
||||||
|
gradients and optimizer states.
|
||||||
|
|
||||||
|
Both are compatible with adding :obj:`cpu_offload` to enable ZeRO-offload (activate it like this: :obj:`--sharded_ddp
|
||||||
|
"zero_dp_2 cpu_offload"`).
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
|
||||||
|
- This feature requires distributed training (so multiple GPUs).
|
||||||
|
- It is not implemented for TPUs.
|
||||||
|
- It works with ``--fp16`` too, to make things even faster.
|
||||||
|
- The ``cpu_offload`` additional option requires ``--fp16``.
|
||||||
|
- This is an area of active development, so make sure you have a source install of fairscale to use this feature as
|
||||||
|
some bugs you encounter may have been fixed there already.
|
||||||
|
|
||||||
|
Known caveats:
|
||||||
|
|
||||||
|
- This feature is incompatible with :obj:`--predict_with_generate` in the `run_seq2seq.py` script.
|
||||||
|
- Using :obj:`--sharded_ddp zero_dp_3` requires wrapping each layer of the model in the special container
|
||||||
|
:obj:`FullyShardedDataParallelism` of fairscale. This is not done automatically by any of the example scripts of the
|
||||||
|
:class:`~transformers.Trainer`.
|
||||||
|
|
||||||
|
|
||||||
DeepSpeed
|
DeepSpeed
|
||||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|||||||
@@ -64,11 +64,12 @@ def require_apex(test_case):
|
|||||||
|
|
||||||
|
|
||||||
class TestTrainerExt(TestCasePlus):
|
class TestTrainerExt(TestCasePlus):
|
||||||
def run_seq2seq_quick(self, distributed=False, extra_args_str=None):
|
def run_seq2seq_quick(self, distributed=False, extra_args_str=None, eval=True, predict_with_generate=True):
|
||||||
output_dir = self.run_trainer(1, "12", MBART_TINY, 1, distributed, extra_args_str)
|
output_dir = self.run_trainer(1, "12", MBART_TINY, 1, distributed, extra_args_str, predict_with_generate)
|
||||||
logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
|
logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
|
||||||
eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
|
eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
|
||||||
first_step_stats = eval_metrics[0]
|
first_step_stats = eval_metrics[0]
|
||||||
|
if predict_with_generate:
|
||||||
assert "eval_bleu" in first_step_stats
|
assert "eval_bleu" in first_step_stats
|
||||||
|
|
||||||
@require_torch_non_multi_gpu
|
@require_torch_non_multi_gpu
|
||||||
@@ -88,14 +89,28 @@ class TestTrainerExt(TestCasePlus):
|
|||||||
# test --sharded_ddp w/o --fp16
|
# test --sharded_ddp w/o --fp16
|
||||||
@require_torch_multi_gpu
|
@require_torch_multi_gpu
|
||||||
@require_fairscale
|
@require_fairscale
|
||||||
def test_run_seq2seq_ddp_sharded_ddp(self):
|
def test_run_seq2seq_sharded_ddp(self):
|
||||||
self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp")
|
self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp simple")
|
||||||
|
|
||||||
# test --sharded_ddp w/ --fp16
|
# test --sharded_ddp w/ --fp16
|
||||||
@require_torch_multi_gpu
|
@require_torch_multi_gpu
|
||||||
@require_fairscale
|
@require_fairscale
|
||||||
def test_run_seq2seq_ddp_sharded_ddp_fp16(self):
|
def test_run_seq2seq_sharded_ddp_fp16(self):
|
||||||
self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp --fp16")
|
self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp simple --fp16")
|
||||||
|
|
||||||
|
# test --sharded_ddp zero2 w/o --fp16
|
||||||
|
@require_torch_multi_gpu
|
||||||
|
@require_fairscale
|
||||||
|
def test_run_seq2seq_fully_sharded_ddp(self):
|
||||||
|
self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp zero2", predict_with_generate=False)
|
||||||
|
|
||||||
|
# test --sharded_ddp zero2 w/ --fp16
|
||||||
|
@require_torch_multi_gpu
|
||||||
|
@require_fairscale
|
||||||
|
def test_run_seq2seq_fully_sharded_ddp_fp16(self):
|
||||||
|
self.run_seq2seq_quick(
|
||||||
|
distributed=True, extra_args_str="--sharded_ddp zero2 --fp16", predict_with_generate=False
|
||||||
|
)
|
||||||
|
|
||||||
@require_apex
|
@require_apex
|
||||||
def test_run_seq2seq_apex(self):
|
def test_run_seq2seq_apex(self):
|
||||||
@@ -131,6 +146,7 @@ class TestTrainerExt(TestCasePlus):
|
|||||||
num_train_epochs: int,
|
num_train_epochs: int,
|
||||||
distributed: bool = False,
|
distributed: bool = False,
|
||||||
extra_args_str: str = None,
|
extra_args_str: str = None,
|
||||||
|
predict_with_generate: bool = True,
|
||||||
):
|
):
|
||||||
data_dir = self.examples_dir / "test_data/wmt_en_ro"
|
data_dir = self.examples_dir / "test_data/wmt_en_ro"
|
||||||
output_dir = self.get_auto_remove_tmp_dir()
|
output_dir = self.get_auto_remove_tmp_dir()
|
||||||
@@ -155,7 +171,6 @@ class TestTrainerExt(TestCasePlus):
|
|||||||
--learning_rate 3e-3
|
--learning_rate 3e-3
|
||||||
--warmup_steps 8
|
--warmup_steps 8
|
||||||
--evaluation_strategy steps
|
--evaluation_strategy steps
|
||||||
--predict_with_generate
|
|
||||||
--logging_steps 0
|
--logging_steps 0
|
||||||
--save_steps {str(eval_steps)}
|
--save_steps {str(eval_steps)}
|
||||||
--eval_steps {str(eval_steps)}
|
--eval_steps {str(eval_steps)}
|
||||||
@@ -165,7 +180,11 @@ class TestTrainerExt(TestCasePlus):
|
|||||||
--task translation
|
--task translation
|
||||||
--target_lang ro_RO
|
--target_lang ro_RO
|
||||||
--source_lang en_XX
|
--source_lang en_XX
|
||||||
""".split()
|
"""
|
||||||
|
if predict_with_generate:
|
||||||
|
args += "--predict_with_generate"
|
||||||
|
|
||||||
|
args = args.split()
|
||||||
|
|
||||||
if extra_args_str is not None:
|
if extra_args_str is not None:
|
||||||
args.extend(extra_args_str.split())
|
args.extend(extra_args_str.split())
|
||||||
|
|||||||
@@ -93,6 +93,7 @@ from .trainer_utils import (
|
|||||||
EvalPrediction,
|
EvalPrediction,
|
||||||
HPSearchBackend,
|
HPSearchBackend,
|
||||||
PredictionOutput,
|
PredictionOutput,
|
||||||
|
ShardedDDPOption,
|
||||||
TrainerMemoryTracker,
|
TrainerMemoryTracker,
|
||||||
TrainOutput,
|
TrainOutput,
|
||||||
default_compute_objective,
|
default_compute_objective,
|
||||||
@@ -131,10 +132,16 @@ if is_torch_tpu_available():
|
|||||||
import torch_xla.distributed.parallel_loader as pl
|
import torch_xla.distributed.parallel_loader as pl
|
||||||
|
|
||||||
if is_fairscale_available():
|
if is_fairscale_available():
|
||||||
|
import fairscale
|
||||||
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
|
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
|
||||||
from fairscale.optim import OSS
|
from fairscale.optim import OSS
|
||||||
from fairscale.optim.grad_scaler import ShardedGradScaler
|
from fairscale.optim.grad_scaler import ShardedGradScaler
|
||||||
|
|
||||||
|
if version.parse(fairscale.__version__) >= version.parse("0.3"):
|
||||||
|
from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP
|
||||||
|
else:
|
||||||
|
FullyShardedDDP = None
|
||||||
|
|
||||||
if is_sagemaker_distributed_available():
|
if is_sagemaker_distributed_available():
|
||||||
import smdistributed.dataparallel.torch.distributed as dist
|
import smdistributed.dataparallel.torch.distributed as dist
|
||||||
from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel as DDP
|
from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel as DDP
|
||||||
@@ -277,9 +284,38 @@ class Trainer:
|
|||||||
else:
|
else:
|
||||||
self.is_model_parallel = False
|
self.is_model_parallel = False
|
||||||
|
|
||||||
|
# Setup Sharded DDP training
|
||||||
|
self.sharded_ddp = None
|
||||||
|
if len(args.sharded_ddp) > 0:
|
||||||
|
if args.deepspeed:
|
||||||
|
raise ValueError(
|
||||||
|
"Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags."
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.local_rank == -1:
|
||||||
|
raise ValueError("Using sharded DDP only works in distributed training.")
|
||||||
|
elif not is_fairscale_available():
|
||||||
|
raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.")
|
||||||
|
elif ShardedDDPOption.SIMPLE not in args.sharded_ddp and FullyShardedDDP is None:
|
||||||
|
raise ImportError(
|
||||||
|
"Sharded DDP in a mode other than simple training requires fairscale version >= 0.3, found "
|
||||||
|
f"{fairscale.__version__}. Upgrade your fairscale library: `pip install --upgrade fairscale`."
|
||||||
|
)
|
||||||
|
elif ShardedDDPOption.SIMPLE in args.sharded_ddp:
|
||||||
|
self.sharded_ddp = ShardedDDPOption.SIMPLE
|
||||||
|
elif ShardedDDPOption.ZERO_DP_2 in args.sharded_ddp:
|
||||||
|
self.sharded_ddp = ShardedDDPOption.ZERO_DP_2
|
||||||
|
elif ShardedDDPOption.ZERO_DP_3 in args.sharded_ddp:
|
||||||
|
self.sharded_ddp = ShardedDDPOption.ZERO_DP_3
|
||||||
|
|
||||||
# one place to sort out whether to place the model on device or not
|
# one place to sort out whether to place the model on device or not
|
||||||
self.place_model_on_device = args.place_model_on_device
|
self.place_model_on_device = args.place_model_on_device
|
||||||
if self.is_model_parallel or (args.deepspeed and args.do_train) or (args.fp16_full_eval and not args.do_train):
|
if (
|
||||||
|
self.is_model_parallel
|
||||||
|
or (args.deepspeed and args.do_train)
|
||||||
|
or (args.fp16_full_eval and not args.do_train)
|
||||||
|
or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3])
|
||||||
|
):
|
||||||
self.place_model_on_device = False
|
self.place_model_on_device = False
|
||||||
|
|
||||||
default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
|
default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
|
||||||
@@ -346,21 +382,6 @@ class Trainer:
|
|||||||
if isinstance(eval_dataset, datasets.Dataset):
|
if isinstance(eval_dataset, datasets.Dataset):
|
||||||
self._remove_unused_columns(self.eval_dataset, description="evaluation")
|
self._remove_unused_columns(self.eval_dataset, description="evaluation")
|
||||||
|
|
||||||
# Setup Sharded DDP training
|
|
||||||
self.sharded_dpp = False
|
|
||||||
if args.sharded_ddp:
|
|
||||||
if args.deepspeed:
|
|
||||||
raise ValueError(
|
|
||||||
"Using --sharded_ddp together with --deepspeed is not possible, deactivate one of those flags."
|
|
||||||
)
|
|
||||||
|
|
||||||
if args.local_rank == -1:
|
|
||||||
raise ValueError("Using sharded DDP only works in distributed training.")
|
|
||||||
elif not is_fairscale_available():
|
|
||||||
raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.")
|
|
||||||
else:
|
|
||||||
self.sharded_dpp = True
|
|
||||||
|
|
||||||
# Mixed precision setup
|
# Mixed precision setup
|
||||||
self.use_apex = False
|
self.use_apex = False
|
||||||
self.use_amp = False
|
self.use_amp = False
|
||||||
@@ -376,7 +397,7 @@ class Trainer:
|
|||||||
if args.fp16 and not args.deepspeed: # deepspeed manages its own fp16
|
if args.fp16 and not args.deepspeed: # deepspeed manages its own fp16
|
||||||
if self.fp16_backend == "amp":
|
if self.fp16_backend == "amp":
|
||||||
self.use_amp = True
|
self.use_amp = True
|
||||||
self.scaler = ShardedGradScaler() if self.sharded_dpp else torch.cuda.amp.GradScaler()
|
self.scaler = ShardedGradScaler() if self.sharded_ddp is not None else torch.cuda.amp.GradScaler()
|
||||||
else:
|
else:
|
||||||
if not is_apex_available():
|
if not is_apex_available():
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
@@ -619,7 +640,7 @@ class Trainer:
|
|||||||
"eps": self.args.adam_epsilon,
|
"eps": self.args.adam_epsilon,
|
||||||
}
|
}
|
||||||
optimizer_kwargs["lr"] = self.args.learning_rate
|
optimizer_kwargs["lr"] = self.args.learning_rate
|
||||||
if self.sharded_dpp:
|
if self.sharded_ddp == ShardedDDPOption.SIMPLE:
|
||||||
self.optimizer = OSS(
|
self.optimizer = OSS(
|
||||||
params=optimizer_grouped_parameters,
|
params=optimizer_grouped_parameters,
|
||||||
optim=optimizer_cls,
|
optim=optimizer_cls,
|
||||||
@@ -737,8 +758,19 @@ class Trainer:
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
# Distributed training (should be after apex fp16 initialization)
|
# Distributed training (should be after apex fp16 initialization)
|
||||||
if self.sharded_dpp:
|
if self.sharded_ddp is not None:
|
||||||
|
# Sharded DDP!
|
||||||
|
if self.sharded_ddp == ShardedDDPOption.SIMPLE:
|
||||||
model = ShardedDDP(model, self.optimizer)
|
model = ShardedDDP(model, self.optimizer)
|
||||||
|
else:
|
||||||
|
mixed_precision = self.args.fp16
|
||||||
|
cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp
|
||||||
|
zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3
|
||||||
|
# XXX: Breaking the self.model convention but I see no way around it for now.
|
||||||
|
self.model = model = FullyShardedDDP(
|
||||||
|
model, mixed_precision=mixed_precision, reshard_after_forward=zero_3, cpu_offload=cpu_offload
|
||||||
|
).to(self.args.device)
|
||||||
|
|
||||||
elif is_sagemaker_distributed_available():
|
elif is_sagemaker_distributed_available():
|
||||||
model = DDP(model, device_ids=[dist.get_local_rank()], broadcast_buffers=False)
|
model = DDP(model, device_ids=[dist.get_local_rank()], broadcast_buffers=False)
|
||||||
elif self.args.local_rank != -1:
|
elif self.args.local_rank != -1:
|
||||||
@@ -855,6 +887,7 @@ class Trainer:
|
|||||||
num_train_epochs = 1
|
num_train_epochs = 1
|
||||||
num_update_steps_per_epoch = max_steps
|
num_update_steps_per_epoch = max_steps
|
||||||
|
|
||||||
|
delay_optimizer_creation = self.sharded_ddp is not None and self.sharded_ddp != ShardedDDPOption.SIMPLE
|
||||||
if self.args.deepspeed:
|
if self.args.deepspeed:
|
||||||
model, optimizer, lr_scheduler = init_deepspeed(self, num_training_steps=max_steps)
|
model, optimizer, lr_scheduler = init_deepspeed(self, num_training_steps=max_steps)
|
||||||
self.model = model.module
|
self.model = model.module
|
||||||
@@ -862,7 +895,7 @@ class Trainer:
|
|||||||
self.deepspeed = model # DeepSpeedEngine object
|
self.deepspeed = model # DeepSpeedEngine object
|
||||||
self.optimizer = optimizer
|
self.optimizer = optimizer
|
||||||
self.lr_scheduler = lr_scheduler
|
self.lr_scheduler = lr_scheduler
|
||||||
else:
|
elif not delay_optimizer_creation:
|
||||||
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
|
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
|
||||||
|
|
||||||
self.state = TrainerState()
|
self.state = TrainerState()
|
||||||
@@ -877,6 +910,9 @@ class Trainer:
|
|||||||
if model is not self.model:
|
if model is not self.model:
|
||||||
self.model_wrapped = model
|
self.model_wrapped = model
|
||||||
|
|
||||||
|
if delay_optimizer_creation:
|
||||||
|
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
|
||||||
|
|
||||||
# important: at this point:
|
# important: at this point:
|
||||||
# self.model is the Transformers Model
|
# self.model is the Transformers Model
|
||||||
# self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc.
|
# self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc.
|
||||||
@@ -1026,6 +1062,9 @@ class Trainer:
|
|||||||
if hasattr(self.optimizer, "clip_grad_norm"):
|
if hasattr(self.optimizer, "clip_grad_norm"):
|
||||||
# Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
|
# Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
|
||||||
self.optimizer.clip_grad_norm(self.args.max_grad_norm)
|
self.optimizer.clip_grad_norm(self.args.max_grad_norm)
|
||||||
|
elif hasattr(model, "clip_grad_norm_"):
|
||||||
|
# Some models (like FullyShardedDDP) have a specific way to do gradient clipping
|
||||||
|
model.clip_grad_norm_(self.args.max_grad_norm)
|
||||||
else:
|
else:
|
||||||
# Revert to normal clipping otherwise, handling Apex or full precision
|
# Revert to normal clipping otherwise, handling Apex or full precision
|
||||||
torch.nn.utils.clip_grad_norm_(
|
torch.nn.utils.clip_grad_norm_(
|
||||||
@@ -1148,8 +1187,8 @@ class Trainer:
|
|||||||
|
|
||||||
def _save_checkpoint(self, model, trial, metrics=None):
|
def _save_checkpoint(self, model, trial, metrics=None):
|
||||||
# In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
|
# In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
|
||||||
# want to save.
|
# want to save except FullyShardedDDP.
|
||||||
assert _model_unwrap(model) is self.model, "internal model should be a reference to self.model"
|
# assert _model_unwrap(model) is self.model, "internal model should be a reference to self.model"
|
||||||
|
|
||||||
# Save model checkpoint
|
# Save model checkpoint
|
||||||
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
|
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
|
||||||
@@ -1173,7 +1212,7 @@ class Trainer:
|
|||||||
self.deepspeed.save_checkpoint(output_dir)
|
self.deepspeed.save_checkpoint(output_dir)
|
||||||
|
|
||||||
# Save optimizer and scheduler
|
# Save optimizer and scheduler
|
||||||
if self.sharded_dpp:
|
if self.sharded_ddp == ShardedDDPOption.SIMPLE:
|
||||||
self.optimizer.consolidate_state_dict()
|
self.optimizer.consolidate_state_dict()
|
||||||
|
|
||||||
if is_torch_tpu_available():
|
if is_torch_tpu_available():
|
||||||
@@ -1479,6 +1518,10 @@ class Trainer:
|
|||||||
# They can then be reloaded using `from_pretrained()`
|
# They can then be reloaded using `from_pretrained()`
|
||||||
xm.rendezvous("saving_checkpoint")
|
xm.rendezvous("saving_checkpoint")
|
||||||
if not isinstance(self.model, PreTrainedModel):
|
if not isinstance(self.model, PreTrainedModel):
|
||||||
|
if isinstance(_model_unwrap(self.model), PreTrainedModel):
|
||||||
|
if xm.is_master_ordinal():
|
||||||
|
_model_unwrap(self.model).config.save_pretrained(output_dir)
|
||||||
|
else:
|
||||||
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
|
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
|
||||||
state_dict = self.model.state_dict()
|
state_dict = self.model.state_dict()
|
||||||
xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
|
xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
|
||||||
@@ -1494,6 +1537,9 @@ class Trainer:
|
|||||||
# Save a trained model and configuration using `save_pretrained()`.
|
# Save a trained model and configuration using `save_pretrained()`.
|
||||||
# They can then be reloaded using `from_pretrained()`
|
# They can then be reloaded using `from_pretrained()`
|
||||||
if not isinstance(self.model, PreTrainedModel):
|
if not isinstance(self.model, PreTrainedModel):
|
||||||
|
if isinstance(_model_unwrap(self.model), PreTrainedModel):
|
||||||
|
_model_unwrap(self.model).config.save_pretrained(output_dir)
|
||||||
|
else:
|
||||||
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
|
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
|
||||||
state_dict = self.model.state_dict()
|
state_dict = self.model.state_dict()
|
||||||
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
|
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
|
||||||
|
|||||||
@@ -421,3 +421,10 @@ class TrainerMemoryTracker:
|
|||||||
# init doesn't have metrics to update so we just save that data for later stages to retrieve
|
# init doesn't have metrics to update so we just save that data for later stages to retrieve
|
||||||
if metrics is not None:
|
if metrics is not None:
|
||||||
self.update_metrics(stage, metrics)
|
self.update_metrics(stage, metrics)
|
||||||
|
|
||||||
|
|
||||||
|
class ShardedDDPOption(ExplicitEnum):
|
||||||
|
SIMPLE = "simple"
|
||||||
|
ZERO_DP_2 = "zero2"
|
||||||
|
ZERO_DP_3 = "zero3"
|
||||||
|
OFFLOAD = "offload"
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ from .file_utils import (
|
|||||||
is_torch_tpu_available,
|
is_torch_tpu_available,
|
||||||
torch_required,
|
torch_required,
|
||||||
)
|
)
|
||||||
from .trainer_utils import EvaluationStrategy, LoggingStrategy, SchedulerType
|
from .trainer_utils import EvaluationStrategy, LoggingStrategy, SchedulerType, ShardedDDPOption
|
||||||
from .utils import logging
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
@@ -236,9 +236,22 @@ class TrainingArguments:
|
|||||||
When resuming training, whether or not to skip the epochs and batches to get the data loading at the same
|
When resuming training, whether or not to skip the epochs and batches to get the data loading at the same
|
||||||
stage as in the previous training. If set to :obj:`True`, the training will begin faster (as that skipping
|
stage as in the previous training. If set to :obj:`True`, the training will begin faster (as that skipping
|
||||||
step can take a long time) but will not yield the same results as the interrupted training would have.
|
step can take a long time) but will not yield the same results as the interrupted training would have.
|
||||||
sharded_ddp (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
sharded_ddp (:obj:`bool`, :obj:`str` or list of :class:`~transformers.trainer_utils.ShardedDDPOption`, `optional`, defaults to :obj:`False`):
|
||||||
Use Sharded DDP training from `FairScale <https://github.com/facebookresearch/fairscale>`__ (in distributed
|
Use Sharded DDP training from `FairScale <https://github.com/facebookresearch/fairscale>`__ (in distributed
|
||||||
training only). This is an experimental feature.
|
training only). This is an experimental feature.
|
||||||
|
|
||||||
|
A list of options along the following:
|
||||||
|
|
||||||
|
- :obj:`"simple"`: to use first instance of sharded DDP released by fairscale (:obj:`ShardedDDP`) similar
|
||||||
|
to ZeRO-2.
|
||||||
|
- :obj:`"zero_dp_2"`: to use the second instance of sharded DPP released by fairscale
|
||||||
|
(:obj:`FullyShardedDDP`) in Zero-2 mode (with :obj:`reshard_after_forward=False`).
|
||||||
|
- :obj:`"zero_dp_3"`: to use the second instance of sharded DPP released by fairscale
|
||||||
|
(:obj:`FullyShardedDDP`) in Zero-3 mode (with :obj:`reshard_after_forward=True`).
|
||||||
|
- :obj:`"offload"`: to add ZeRO-offload (only compatible with :obj:`"zero_dp_2"` and :obj:`"zero_dp_3"`).
|
||||||
|
|
||||||
|
If a string is passed, it will be split on space. If a bool is passed, it will be converted to an empty
|
||||||
|
list for :obj:`False` and :obj:`["simple"]` for :obj:`True`.
|
||||||
deepspeed (:obj:`str`, `optional`):
|
deepspeed (:obj:`str`, `optional`):
|
||||||
Use `Deepspeed <https://github.com/microsoft/deepspeed>`__. This is an experimental feature and its API may
|
Use `Deepspeed <https://github.com/microsoft/deepspeed>`__. This is an experimental feature and its API may
|
||||||
evolve in the future. The value is the location of its json config file (usually ``ds_config.json``).
|
evolve in the future. The value is the location of its json config file (usually ``ds_config.json``).
|
||||||
@@ -443,9 +456,14 @@ class TrainingArguments:
|
|||||||
"help": "When resuming training, whether or not to skip the first epochs and batches to get to the same training data."
|
"help": "When resuming training, whether or not to skip the first epochs and batches to get to the same training data."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
sharded_ddp: bool = field(
|
sharded_ddp: str = field(
|
||||||
default=False,
|
default="",
|
||||||
metadata={"help": "Whether or not to use sharded DDP training (in distributed training only)."},
|
metadata={
|
||||||
|
"choices": ["simple", "zero_dp_2", "zero_dp_3", "zero_dp_2 offload", "zero_dp_3 offload"],
|
||||||
|
"help": "Whether or not to use sharded DDP training (in distributed training only). The base option "
|
||||||
|
"should be `simple`, `zero_dp_2` or `zero_dp_3` and you can add CPU-offload to `zero_dp_2` or `zero_dp_3` "
|
||||||
|
"like this: zero_dp_2 offload` or `zero_dp_3 offload`",
|
||||||
|
},
|
||||||
)
|
)
|
||||||
deepspeed: Optional[str] = field(
|
deepspeed: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
@@ -535,6 +553,20 @@ class TrainingArguments:
|
|||||||
"Both warmup_ratio and warmup_steps given, warmup_steps will override any effect of warmup_ratio during training"
|
"Both warmup_ratio and warmup_steps given, warmup_steps will override any effect of warmup_ratio during training"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if isinstance(self.sharded_ddp, bool):
|
||||||
|
self.sharded_ddp = "simple" if self.sharded_ddp else ""
|
||||||
|
if isinstance(self.sharded_ddp, str):
|
||||||
|
self.sharded_ddp = [ShardedDDPOption(s) for s in self.sharded_ddp.split()]
|
||||||
|
if self.sharded_ddp == [ShardedDDPOption.OFFLOAD]:
|
||||||
|
raise ValueError(
|
||||||
|
"`--sharded_ddp offload` can't work on its own. It needs to be added to `--sharded_ddp zero_dp_2` or "
|
||||||
|
'`--sharded_ddp zero_dp_3`. For example, `--sharded_ddp "zero_dp_2 offload"`.'
|
||||||
|
)
|
||||||
|
elif len(self.sharded_ddp) > 1 and ShardedDDPOption.Simple in self.sharded_ddp:
|
||||||
|
raise ValueError("`--sharded_ddp simple` is not compatible with any other option.")
|
||||||
|
elif ShardedDDPOption.ZERO_DP_2 in self.sharded_ddp and ShardedDDPOption.ZERO_DP_3 in self.sharded_ddp:
|
||||||
|
raise ValueError("`--sharded_ddp zero_dp_2` is not compatible with `--sharded_ddp zero_dp_3`.")
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
# We override the default repr to remove deprecated arguments from the repr. This method should be removed once
|
# We override the default repr to remove deprecated arguments from the repr. This method should be removed once
|
||||||
# those deprecated arguments are removed form TrainingArguments. (TODO: v5)
|
# those deprecated arguments are removed form TrainingArguments. (TODO: v5)
|
||||||
@@ -662,7 +694,7 @@ class TrainingArguments:
|
|||||||
|
|
||||||
- :obj:`ParallelMode.NOT_PARALLEL`: no parallelism (CPU or one GPU).
|
- :obj:`ParallelMode.NOT_PARALLEL`: no parallelism (CPU or one GPU).
|
||||||
- :obj:`ParallelMode.NOT_DISTRIBUTED`: several GPUs in one single process (uses :obj:`torch.nn.DataParallel`).
|
- :obj:`ParallelMode.NOT_DISTRIBUTED`: several GPUs in one single process (uses :obj:`torch.nn.DataParallel`).
|
||||||
- :obj:`ParallelMode.DISTRIBUTED`: several GPUs, each ahving its own process (uses
|
- :obj:`ParallelMode.DISTRIBUTED`: several GPUs, each having its own process (uses
|
||||||
:obj:`torch.nn.DistributedDataParallel`).
|
:obj:`torch.nn.DistributedDataParallel`).
|
||||||
- :obj:`ParallelMode.TPU`: several TPU cores.
|
- :obj:`ParallelMode.TPU`: several TPU cores.
|
||||||
"""
|
"""
|
||||||
@@ -692,6 +724,8 @@ class TrainingArguments:
|
|||||||
for k, v in d.items():
|
for k, v in d.items():
|
||||||
if isinstance(v, Enum):
|
if isinstance(v, Enum):
|
||||||
d[k] = v.value
|
d[k] = v.value
|
||||||
|
if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
|
||||||
|
d[k] = [x.value for x in v]
|
||||||
return d
|
return d
|
||||||
|
|
||||||
def to_json_string(self):
|
def to_json_string(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user