From cd56f3fe7eae4a53a9880e3f5e8f91877a78271c Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Wed, 31 Mar 2021 10:01:30 -0400 Subject: [PATCH] Merge trainers (#10975) * Replace is_sagemaker_distributed_available * Merge SageMakerTrainer into Trainer * Test with shorter condition * Put back deleted line * Deprecate SageMakerTrainer and SageMakerTrainingArguments * Apply suggestions from code review Co-authored-by: Philipp Schmid <32632186+philschmid@users.noreply.github.com> Co-authored-by: Philipp Schmid <32632186+philschmid@users.noreply.github.com> --- src/transformers/file_utils.py | 26 +++- src/transformers/sagemaker/__init__.py | 2 +- src/transformers/sagemaker/trainer_sm.py | 5 + .../sagemaker/training_args_sm.py | 13 +- src/transformers/trainer.py | 118 ++++++++++++++---- src/transformers/trainer_pt_utils.py | 41 +++++- src/transformers/trainer_utils.py | 4 +- src/transformers/training_args.py | 41 ++++-- .../scripts/tensorflow/run_tf_dist.py | 4 +- 9 files changed, 210 insertions(+), 44 deletions(-) diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index 597435fad2..8e62eca94a 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -352,7 +352,7 @@ def is_pandas_available(): return importlib.util.find_spec("pandas") is not None -def is_sagemaker_distributed_available(): +def is_sagemaker_dp_enabled(): # Get the sagemaker specific env variable. sagemaker_params = os.getenv("SM_FRAMEWORK_PARAMS", "{}") try: @@ -366,6 +366,30 @@ def is_sagemaker_distributed_available(): return importlib.util.find_spec("smdistributed") is not None +def is_sagemaker_mp_enabled(): + # Get the sagemaker specific mp parameters from smp_options variable. + smp_options = os.getenv("SM_HP_MP_PARAMETERS", "{}") + try: + # Parse it and check the field "partitions" is included, it is required for model parallel. 
+ smp_options = json.loads(smp_options) + if "partitions" not in smp_options: + return False + except json.JSONDecodeError: + return False + + # Get the sagemaker specific framework parameters from mpi_options variable. + mpi_options = os.getenv("SM_FRAMEWORK_PARAMS", "{}") + try: + # Parse it and check the field "sagemaker_mpi_enabled". + mpi_options = json.loads(mpi_options) + if not mpi_options.get("sagemaker_mpi_enabled", False): + return False + except json.JSONDecodeError: + return False + # Lastly, check if the `smdistributed` module is present. + return importlib.util.find_spec("smdistributed") is not None + + def is_training_run_on_sagemaker(): return "SAGEMAKER_JOB_NAME" in os.environ diff --git a/src/transformers/sagemaker/__init__.py b/src/transformers/sagemaker/__init__.py index 46222fdf7c..22bdaf2946 100644 --- a/src/transformers/sagemaker/__init__.py +++ b/src/transformers/sagemaker/__init__.py @@ -17,4 +17,4 @@ # limitations under the License. from .trainer_sm import SageMakerTrainer -from .training_args_sm import SageMakerTrainingArguments, is_sagemaker_distributed_available +from .training_args_sm import SageMakerTrainingArguments, is_sagemaker_dp_enabled diff --git a/src/transformers/sagemaker/trainer_sm.py b/src/transformers/sagemaker/trainer_sm.py index 1ea9a8f40b..bc725fd647 100644 --- a/src/transformers/sagemaker/trainer_sm.py +++ b/src/transformers/sagemaker/trainer_sm.py @@ -79,6 +79,11 @@ if is_sagemaker_model_parallel_available(): class SageMakerTrainer(Trainer): def __init__(self, args=None, **kwargs): + warnings.warn( + "`SageMakerTrainer` is deprecated and will be removed in v5 of Transformers. 
You can use `Trainer` " + "instead.", + FutureWarning, + ) self.is_model_parallel_enabled = is_sagemaker_model_parallel_available() super().__init__(args=args, **kwargs) diff --git a/src/transformers/sagemaker/training_args_sm.py b/src/transformers/sagemaker/training_args_sm.py index e6cbf8dd37..0a01c1dc0f 100644 --- a/src/transformers/sagemaker/training_args_sm.py +++ b/src/transformers/sagemaker/training_args_sm.py @@ -15,11 +15,12 @@ import importlib.util import json import os +import warnings from dataclasses import dataclass, field import torch -from transformers.file_utils import cached_property, is_sagemaker_distributed_available +from transformers.file_utils import cached_property, is_sagemaker_dp_enabled from transformers.training_args import TrainingArguments from transformers.utils import logging @@ -66,6 +67,14 @@ class SageMakerTrainingArguments(TrainingArguments): metadata={"help": "Used by the SageMaker launcher to send mp-specific args. Ignored in SageMakerTrainer"}, ) + def __post_init__(self): + super().__post_init__() + warnings.warn( + "`SageMakerTrainingArguments` is deprecated and will be removed in v5 of Transformers. 
You can use " + "`TrainingArguments` instead.", + FutureWarning, + ) + @cached_property def _setup_devices(self) -> "torch.device": logger.info("PyTorch: setting up devices") @@ -76,7 +85,7 @@ class SageMakerTrainingArguments(TrainingArguments): local_rank = smp.local_rank() device = torch.device("cuda", local_rank) self._n_gpu = 1 - elif is_sagemaker_distributed_available(): + elif is_sagemaker_dp_enabled(): import smdistributed.dataparallel.torch.distributed as dist dist.init_process_group() diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 27b1ed90fa..7c33981b6d 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -59,7 +59,8 @@ from .file_utils import ( is_apex_available, is_datasets_available, is_in_notebook, - is_sagemaker_distributed_available, + is_sagemaker_dp_enabled, + is_sagemaker_mp_enabled, is_torch_tpu_available, is_training_run_on_sagemaker, ) @@ -149,12 +150,17 @@ if is_fairscale_available(): else: FullyShardedDDP = None -if is_sagemaker_distributed_available(): +if is_sagemaker_dp_enabled(): import smdistributed.dataparallel.torch.distributed as dist from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel as DDP else: import torch.distributed as dist +if is_sagemaker_mp_enabled(): + import smdistributed.modelparallel.torch as smp + + from .trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat + if is_training_run_on_sagemaker(): logging.add_handler(StreamHandler(sys.stdout)) @@ -522,7 +528,10 @@ class Trainer: else: if self.args.world_size <= 1: return RandomSampler(self.train_dataset) - elif self.args.parallel_mode == ParallelMode.TPU and not self.args.dataloader_drop_last: + elif ( + self.args.parallel_mode in [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL] + and not self.args.dataloader_drop_last + ): # Use a loop for TPUs when drop_last is False to have all batches have the same size. 
return DistributedSamplerWithLoop( self.train_dataset, @@ -561,6 +570,13 @@ class Trainer: def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]: if is_torch_tpu_available(): return SequentialDistributedSampler(eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()) + elif is_sagemaker_mp_enabled(): + return SequentialDistributedSampler( + eval_dataset, + num_replicas=smp.dp_size(), + rank=smp.dp_rank(), + batch_size=self.args.per_device_eval_batch_size, + ) elif self.args.local_rank != -1: return SequentialDistributedSampler(eval_dataset) else: @@ -674,6 +690,9 @@ class Trainer: else: self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) + if is_sagemaker_mp_enabled(): + self.optimizer = smp.DistributedOptimizer(self.optimizer) + def create_scheduler(self, num_training_steps: int): """ Setup the scheduler. The optimizer of the trainer must have been set up before this method is called. @@ -775,6 +794,12 @@ class Trainer: return model def _wrap_model(self, model, training=True): + if is_sagemaker_mp_enabled(): + # Wrapping the base model twice in a DistributedModel will raise an error. 
+ if isinstance(self.model_wrapped, smp.model.DistributedModel): + return self.model_wrapped + return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps) + # already initialized its own DDP and AMP if self.deepspeed: return self.deepspeed @@ -815,7 +840,7 @@ class Trainer: cpu_offload=cpu_offload, ).to(self.args.device) - elif is_sagemaker_distributed_available(): + elif is_sagemaker_dp_enabled(): model = DDP(model, device_ids=[dist.get_local_rank()], broadcast_buffers=False) elif self.args.local_rank != -1: if self.args.ddp_find_unused_parameters is not None: @@ -1280,6 +1305,15 @@ class Trainer: with warnings.catch_warnings(record=True) as caught_warnings: xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) reissue_pt_warnings(caught_warnings) + elif is_sagemaker_mp_enabled(): + # Consolidate the state dict on all processes of dp_rank 0 + opt_state_dict = self.optimizer.state_dict() + # Save it and the scheduler on the main process + if self.is_world_process_zero(): + torch.save(opt_state_dict, os.path.join(output_dir, "optimizer.pt")) + with warnings.catch_warnings(record=True) as caught_warnings: + torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) + reissue_pt_warnings(caught_warnings) elif self.is_world_process_zero() and not self.deepspeed: # deepspeed.save_checkpoint above saves model/optim/sched torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) @@ -1337,8 +1371,9 @@ class Trainer: self.optimizer.load_state_dict(optimizer_state) self.lr_scheduler.load_state_dict(lr_scheduler_state) else: + map_location = "cpu" if is_sagemaker_mp_enabled() else self.args.device self.optimizer.load_state_dict( - torch.load(os.path.join(checkpoint, "optimizer.pt"), map_location=self.args.device) + torch.load(os.path.join(checkpoint, "optimizer.pt"), map_location=map_location) ) with warnings.catch_warnings(record=True) as caught_warnings: 
self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, "scheduler.pt"))) @@ -1478,6 +1513,10 @@ class Trainer: model.train() inputs = self._prepare_inputs(inputs) + if is_sagemaker_mp_enabled(): + loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps) + return loss_mb.reduce_mean().detach().to(self.args.device) + if self.use_amp: with autocast(): loss = self.compute_loss(model, inputs) @@ -1535,6 +1574,8 @@ class Trainer: """ if is_torch_tpu_available(): return xm.is_master_ordinal(local=True) + elif is_sagemaker_mp_enabled(): + return smp.local_rank() == 0 else: return self.args.local_rank in [-1, 0] @@ -1545,8 +1586,10 @@ class Trainer: """ if is_torch_tpu_available(): return xm.is_master_ordinal(local=False) + elif is_sagemaker_mp_enabled(): + return smp.rank() == 0 else: - return self.args.local_rank == -1 or dist.get_rank() == 0 + return self.args.process_index == 0 def save_model(self, output_dir: Optional[str] = None): """ @@ -1556,6 +1599,11 @@ class Trainer: """ if is_torch_tpu_available(): self._save_tpu(output_dir) + elif is_sagemaker_mp_enabled(): + # Calling the state_dict needs to be done on the wrapped model and on all processes. 
+ state_dict = self.model_wrapped.state_dict() + if self.is_world_process_zero(): + self._save(output_dir, state_dict=state_dict) elif ( ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp ): @@ -1905,6 +1953,8 @@ class Trainer: return if is_torch_tpu_available(): tensors = nested_xla_mesh_reduce(tensors, name) + elif is_sagemaker_mp_enabled(): + tensors = smp_gather(tensors) elif self.args.local_rank != -1: tensors = distributed_concat(tensors) @@ -1957,27 +2007,47 @@ class Trainer: labels = None with torch.no_grad(): - if has_labels: - loss, outputs = self.compute_loss(model, inputs, return_outputs=True) - loss = loss.mean().detach() - if isinstance(outputs, dict): - logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"]) + if is_sagemaker_mp_enabled(): + raw_outputs = smp_forward_only(model, inputs) + if has_labels: + if isinstance(raw_outputs, dict): + loss_mb = raw_outputs["loss"] + logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"]) + else: + loss_mb = raw_outputs[0] + logits_mb = raw_outputs[1:] + + loss = loss_mb.reduce_mean().detach().cpu() + logits = smp_nested_concat(logits_mb) else: - logits = outputs[1:] + loss = None + if isinstance(raw_outputs, dict): + logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys) + else: + logits_mb = raw_outputs + logits = smp_nested_concat(logits_mb) else: - loss = None - if self.use_amp: - with autocast(): + if has_labels: + loss, outputs = self.compute_loss(model, inputs, return_outputs=True) + loss = loss.mean().detach() + if isinstance(outputs, dict): + logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"]) + else: + logits = outputs[1:] + else: + loss = None + if self.use_amp: + with autocast(): + outputs = model(**inputs) + else: outputs = model(**inputs) - else: - outputs = model(**inputs) - if isinstance(outputs, dict): - logits = tuple(v for k, v in 
outputs.items() if k not in ignore_keys) - else: - logits = outputs - # TODO: this needs to be fixed and made cleaner later. - if self.args.past_index >= 0: - self._past = outputs[self.args.past_index - 1] + if isinstance(outputs, dict): + logits = tuple(v for k, v in outputs.items() if k not in ignore_keys) + else: + logits = outputs + # TODO: this needs to be fixed and made cleaner later. + if self.args.past_index >= 0: + self._past = outputs[self.args.past_index - 1] if prediction_loss_only: return (loss, None, None) diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py index 5f2bf82421..b9744f81bd 100644 --- a/src/transformers/trainer_pt_utils.py +++ b/src/transformers/trainer_pt_utils.py @@ -32,11 +32,11 @@ from torch.utils.data.dataset import Dataset from torch.utils.data.distributed import DistributedSampler from torch.utils.data.sampler import RandomSampler, Sampler -from .file_utils import is_sagemaker_distributed_available, is_torch_tpu_available +from .file_utils import is_sagemaker_dp_enabled, is_sagemaker_mp_enabled, is_torch_tpu_available from .utils import logging -if is_sagemaker_distributed_available(): +if is_sagemaker_dp_enabled(): import smdistributed.dataparallel.torch.distributed as dist else: import torch.distributed as dist @@ -805,3 +805,40 @@ def get_parameter_names(model, forbidden_layer_types): # Add model specific parameters (defined with nn.Parameter) since they are not in any child. 
result += list(model._parameters.keys()) return result + + +if is_sagemaker_mp_enabled(): + import smdistributed.modelparallel.torch as smp + + @smp.step() + def smp_forward_backward(model, inputs, gradient_accumulation_steps=1): + outputs = model(**inputs) + loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] + loss /= gradient_accumulation_steps + model.backward(loss) + return loss + + @smp.step() + def smp_forward_only(model, inputs): + return model(**inputs) + + def smp_gather(tensor): + if isinstance(tensor, (list, tuple)): + return type(tensor)(smp_gather(t) for t in tensor) + elif isinstance(tensor, dict): + return type(tensor)({k: smp_gather(v) for k, v in tensor.items()}) + elif not isinstance(tensor, torch.Tensor): + raise TypeError( + f"Can't gather the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors." + ) + all_tensors = smp.allgather(tensor, smp.CommGroup.DP_GROUP) + return torch.cat([t.cpu() for t in all_tensors], dim=0) + + def smp_nested_concat(tensor): + if isinstance(tensor, (list, tuple)): + return type(tensor)(smp_nested_concat(t) for t in tensor) + elif isinstance(tensor, dict): + return type(tensor)({k: smp_nested_concat(v) for k, v in tensor.items()}) + # It doesn't seem possible to check here if `tensor` is a StepOutput because StepOutput lives in `smp.step` + # which is also the name of the decorator so Python is confused. 
+ return tensor.concat().detach().cpu() diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py index 2108d3d3bc..71df8bc8de 100644 --- a/src/transformers/trainer_utils.py +++ b/src/transformers/trainer_utils.py @@ -31,7 +31,7 @@ import numpy as np from .file_utils import ( ExplicitEnum, is_psutil_available, - is_sagemaker_distributed_available, + is_sagemaker_dp_enabled, is_tf_available, is_torch_available, is_torch_cuda_available, @@ -214,7 +214,7 @@ def total_processes_number(local_rank): import torch_xla.core.xla_model as xm return xm.xrt_world_size() - elif is_sagemaker_distributed_available(): + elif is_sagemaker_dp_enabled(): import smdistributed.dataparallel.torch.distributed as dist return dist.get_world_size() diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 65431cb542..3a870ee81c 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -21,7 +21,8 @@ from typing import Any, Dict, List, Optional from .file_utils import ( cached_property, - is_sagemaker_distributed_available, + is_sagemaker_dp_enabled, + is_sagemaker_mp_enabled, is_torch_available, is_torch_tpu_available, torch_required, @@ -36,9 +37,14 @@ if is_torch_available(): if is_torch_tpu_available(): import torch_xla.core.xla_model as xm -if is_sagemaker_distributed_available(): +if is_sagemaker_dp_enabled(): import smdistributed.dataparallel.torch.distributed as sm_dist +if is_sagemaker_mp_enabled(): + import smdistributed.modelparallel.torch as smp + + smp.init() + logger = logging.get_logger(__name__) @@ -519,6 +525,10 @@ class TrainingArguments: default=False, metadata={"help": "Whether or not to skip adding of memory profiler reports to metrics."} ) _n_gpu: int = field(init=False, repr=False, default=-1) + mp_parameters: str = field( + default="", + metadata={"help": "Used by the SageMaker launcher to send mp-specific args. 
Ignored in Trainer"}, + ) def __post_init__(self): # expand paths, if not os.makedirs("~/bar") will make directory @@ -646,7 +656,11 @@ class TrainingArguments: elif is_torch_tpu_available(): device = xm.xla_device() self._n_gpu = 0 - elif is_sagemaker_distributed_available(): + elif is_sagemaker_mp_enabled(): + local_rank = smp.local_rank() + device = torch.device("cuda", local_rank) + self._n_gpu = 1 + elif is_sagemaker_dp_enabled(): sm_dist.init_process_group() self.local_rank = sm_dist.get_local_rank() device = torch.device("cuda", self.local_rank) @@ -730,8 +744,10 @@ class TrainingArguments: """ if is_torch_tpu_available(): return ParallelMode.TPU - elif is_sagemaker_distributed_available(): - return ParallelMode.SAGEMAKER_DISTRIBUTED + elif is_sagemaker_mp_enabled(): + return ParallelMode.SAGEMAKER_MODEL_PARALLEL + elif is_sagemaker_dp_enabled(): + return ParallelMode.SAGEMAKER_DATA_PARALLEL elif self.local_rank != -1: return ParallelMode.DISTRIBUTED elif self.n_gpu > 1: @@ -747,7 +763,9 @@ class TrainingArguments: """ if is_torch_tpu_available(): return xm.xrt_world_size() - elif is_sagemaker_distributed_available(): + elif is_sagemaker_mp_enabled(): + return smp.dp_size() + elif is_sagemaker_dp_enabled(): return sm_dist.get_world_size() elif self.local_rank != -1: return torch.distributed.get_world_size() @@ -761,7 +779,9 @@ class TrainingArguments: """ if is_torch_tpu_available(): return xm.get_ordinal() - elif is_sagemaker_distributed_available(): + elif is_sagemaker_mp_enabled(): + return smp.dp_rank() + elif is_sagemaker_dp_enabled(): return sm_dist.get_rank() elif self.local_rank != -1: return torch.distributed.get_rank() @@ -772,14 +792,14 @@ class TrainingArguments: """ Can be subclassed and overridden for some specific integrations. """ - return True + return not is_sagemaker_mp_enabled() @property def _no_sync_in_gradient_accumulation(self): """ Whether or not to use no_sync for the gradients when doing gradient accumulation. 
""" - return not self.deepspeed + return not (self.deepspeed or is_sagemaker_mp_enabled()) def to_dict(self): """ @@ -817,5 +837,6 @@ class ParallelMode(Enum): NOT_PARALLEL = "not_parallel" NOT_DISTRIBUTED = "not_distributed" DISTRIBUTED = "distributed" - SAGEMAKER_DISTRIBUTED = "sm_distributed" + SAGEMAKER_MODEL_PARALLEL = "sagemaker_model_parallel" + SAGEMAKER_DATA_PARALLEL = "sagemaker_data_parallel" TPU = "tpu" diff --git a/tests/sagemaker/scripts/tensorflow/run_tf_dist.py b/tests/sagemaker/scripts/tensorflow/run_tf_dist.py index 0c1838ce9a..4ff709d037 100644 --- a/tests/sagemaker/scripts/tensorflow/run_tf_dist.py +++ b/tests/sagemaker/scripts/tensorflow/run_tf_dist.py @@ -9,10 +9,10 @@ from datasets import load_dataset from tqdm import tqdm from transformers import AutoTokenizer, TFAutoModelForSequenceClassification -from transformers.file_utils import is_sagemaker_distributed_available +from transformers.file_utils import is_sagemaker_dp_enabled -if os.environ.get("SDP_ENABLED") or is_sagemaker_distributed_available(): +if os.environ.get("SDP_ENABLED") or is_sagemaker_dp_enabled(): SDP_ENABLED = True os.environ["SAGEMAKER_INSTANCE_TYPE"] = "p3dn.24xlarge" import smdistributed.dataparallel.tensorflow as sdp