From 03462875cc2d6506eb66a74de7d19b93ce968596 Mon Sep 17 00:00:00 2001 From: Zachary Mueller Date: Mon, 17 Apr 2023 15:09:45 -0400 Subject: [PATCH] Introduce `PartialState` as the device handler in the `Trainer` (#22752) * Use accelerate for device management * Add accelerate to setup Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- setup.py | 2 +- src/transformers/trainer.py | 25 ++-- src/transformers/training_args.py | 153 ++++++---------------- tests/trainer/test_trainer_distributed.py | 16 +-- 4 files changed, 56 insertions(+), 140 deletions(-) diff --git a/setup.py b/setup.py index 278825a5c6..eb954ad71d 100644 --- a/setup.py +++ b/setup.py @@ -260,7 +260,7 @@ extras["sklearn"] = deps_list("scikit-learn") extras["tf"] = deps_list("tensorflow", "onnxconverter-common", "tf2onnx", "tensorflow-text", "keras-nlp") extras["tf-cpu"] = deps_list("tensorflow-cpu", "onnxconverter-common", "tf2onnx", "tensorflow-text", "keras-nlp") -extras["torch"] = deps_list("torch") +extras["torch"] = deps_list("torch", "accelerate") extras["accelerate"] = deps_list("accelerate") if os.name == "nt": # windows diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index cf71499b00..3b4da72838 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -416,8 +416,7 @@ class Trainer: raise ValueError( "Using --sharded_ddp xxx together with --fsdp is not possible, deactivate one of those flags." ) - - if args.local_rank == -1: + if args.parallel_mode != ParallelMode.DISTRIBUTED: raise ValueError("Using sharded DDP only works in distributed training.") elif not is_fairscale_available(): raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.") @@ -439,7 +438,7 @@ class Trainer: raise ValueError( "Using --fsdp xxx together with --deepspeed is not possible, deactivate one of those flags." ) - if not args.fsdp_config["xla"] and args.local_rank == -1: + if not args.fsdp_config["xla"] and args.parallel_mode != ParallelMode.DISTRIBUTED: raise ValueError("Using fsdp only works in distributed training.") # dep_version_check("torch>=1.12.0") @@ -551,7 +550,7 @@ class Trainer: # In case of pull, we need to make sure every process has the latest. if is_torch_tpu_available(): xm.rendezvous("init git repo") - elif args.local_rank != -1: + elif args.parallel_mode == ParallelMode.DISTRIBUTED: dist.barrier() if self.args.should_save: @@ -929,7 +928,7 @@ class Trainer: rank=smp.dp_rank(), batch_size=self.args.per_device_eval_batch_size, ) - elif self.args.local_rank != -1: + elif self.args.parallel_mode == ParallelMode.DISTRIBUTED: return SequentialDistributedSampler(eval_dataset) else: return SequentialSampler(eval_dataset) @@ -1551,7 +1550,7 @@ class Trainer: model = nn.parallel.DistributedDataParallel( model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))] ) - elif self.args.local_rank != -1: + elif self.args.parallel_mode == ParallelMode.DISTRIBUTED: kwargs = {} if self.args.ddp_find_unused_parameters is not None: kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters @@ -1919,7 +1918,7 @@ class Trainer: if ( (total_batched_samples % args.gradient_accumulation_steps != 0) - and args.local_rank != -1 + and args.parallel_mode == ParallelMode.DISTRIBUTED and args._no_sync_in_gradient_accumulation ): # Avoid unnecessary DDP synchronization since there will be no backward pass on this example. @@ -2041,7 +2040,7 @@ class Trainer: # Wait for everyone to get here so we are sur the model has been saved by process 0. if is_torch_tpu_available(): xm.rendezvous("load_best_model_at_end") - elif args.local_rank != -1: + elif args.parallel_mode == ParallelMode.DISTRIBUTED: dist.barrier() elif is_sagemaker_mp_enabled(): smp.barrier() @@ -2319,7 +2318,7 @@ class Trainer: np.random.set_state(checkpoint_rng_state["numpy"]) torch.random.set_rng_state(checkpoint_rng_state["cpu"]) if torch.cuda.is_available(): - if self.args.local_rank != -1: + if self.args.parallel_mode == ParallelMode.DISTRIBUTED: torch.cuda.random.set_rng_state(checkpoint_rng_state["cuda"]) else: try: @@ -2413,7 +2412,7 @@ class Trainer: "cpu": torch.random.get_rng_state(), } if torch.cuda.is_available(): - if self.args.local_rank == -1: + if self.args.parallel_mode == ParallelMode.DISTRIBUTED: # In non distributed, we save the global CUDA RNG state (will take care of DataParallel) rng_states["cuda"] = torch.cuda.random.get_rng_state_all() else: @@ -2895,7 +2894,7 @@ class Trainer: def store_flos(self): # Storing the number of floating-point operations that went into the model - if self.args.local_rank != -1: + if self.args.parallel_mode == ParallelMode.DISTRIBUTED: self.state.total_flos += ( distributed_broadcast_scalars([self.current_flos], device=self.args.device).sum().item() ) @@ -3310,7 +3309,7 @@ class Trainer: tensors = nested_xla_mesh_reduce(tensors, name) elif is_sagemaker_mp_enabled(): tensors = smp_gather(tensors) - elif self.args.local_rank != -1: + elif self.args.parallel_mode == ParallelMode.DISTRIBUTED: tensors = distributed_concat(tensors) return tensors @@ -3834,7 +3833,7 @@ class Trainer: tensors = nested_xla_mesh_reduce(tensors, name) elif is_sagemaker_mp_enabled(): tensors = smp_gather(tensors) - elif self.args.local_rank != -1: + elif self.args.parallel_mode == ParallelMode.DISTRIBUTED: tensors = distributed_concat(tensors) return nested_numpify(tensors) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 30c2461ffa..6af20f4549 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -38,10 +38,8 @@ from .trainer_utils import ( from .utils import ( ExplicitEnum, cached_property, - ccl_version, get_full_repo_name, is_accelerate_available, - is_psutil_available, is_safetensors_available, is_sagemaker_dp_enabled, is_sagemaker_mp_enabled, @@ -65,6 +63,10 @@ if is_torch_available(): import torch import torch.distributed as dist +if is_accelerate_available(): + from accelerate import PartialState + from accelerate.utils import DistributedType + if is_torch_tpu_available(check_device=False): import torch_xla.core.xla_model as xm @@ -1122,12 +1124,6 @@ class TrainingArguments: ) def __post_init__(self): - # Handle --use_env option in torch.distributed.launch (local_rank not passed as an arg then). - # This needs to happen before any call to self.device or self.n_gpu. - env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) - if env_local_rank != -1 and env_local_rank != self.local_rank: - self.local_rank = env_local_rank - # expand paths, if not os.makedirs("~/bar") will make directory # in the current directory instead of the actual home # see https://github.com/huggingface/transformers/issues/10628 @@ -1535,104 +1531,40 @@ class TrainingArguments: def _setup_devices(self) -> "torch.device": requires_backends(self, ["torch"]) logger.info("PyTorch: setting up devices") - if torch.distributed.is_available() and torch.distributed.is_initialized() and self.local_rank == -1: - logger.warning( - "torch.distributed process group is initialized, but local_rank == -1. " - "In order to use Torch DDP, launch your script with `python -m torch.distributed.launch" - ) if self.no_cuda: - device = torch.device("cpu") - self._n_gpu = 0 - self.local_rank = get_int_from_env( - ["LOCAL_RANK", "MPI_LOCALRANKID", "OMPI_COMM_WORLD_LOCAL_RANK", "MV2_COMM_WORLD_LOCAL_RANK"], - self.local_rank, - ) - if self.local_rank != -1 and not torch.distributed.is_initialized(): - # Initializes distributed backend for cpu - if self.xpu_backend not in ("mpi", "ccl", "gloo"): - raise ValueError( - "CPU distributed training backend is not properly set. " - "Please set '--xpu_backend' to either 'mpi' or 'ccl' or 'gloo'." - ) - if self.xpu_backend == "ccl": - requires_backends(self, "oneccl_bind_pt") - if ccl_version >= "1.12": - import oneccl_bindings_for_pytorch # noqa: F401 - else: - import torch_ccl # noqa: F401 - if int(os.environ.get("CCL_WORKER_COUNT", 0)) < 1: - raise ValueError( - "CPU distributed training backend is ccl. but CCL_WORKER_COUNT is not correctly set. " - "Please use like 'export CCL_WORKER_COUNT = 1' to set." - ) - - # Try to get launch configuration from environment variables set by MPI launcher - works for Intel MPI, OpenMPI and MVAPICH - rank = get_int_from_env(["RANK", "PMI_RANK", "OMPI_COMM_WORLD_RANK", "MV2_COMM_WORLD_RANK"], 0) - size = get_int_from_env(["WORLD_SIZE", "PMI_SIZE", "OMPI_COMM_WORLD_SIZE", "MV2_COMM_WORLD_SIZE"], 1) - local_size = get_int_from_env( - ["MPI_LOCALNRANKS", "OMPI_COMM_WORLD_LOCAL_SIZE", "MV2_COMM_WORLD_LOCAL_SIZE"], 1 - ) - os.environ["RANK"] = str(rank) - os.environ["WORLD_SIZE"] = str(size) - os.environ["LOCAL_RANK"] = str(self.local_rank) - if not os.environ.get("MASTER_PORT", None): - os.environ["MASTER_PORT"] = "29500" - if not os.environ.get("MASTER_ADDR", None): - if local_size != size or self.xpu_backend != "mpi": - raise ValueError( - "Looks like distributed multinode run but MASTER_ADDR env not set, " - "please try exporting rank 0's hostname as MASTER_ADDR" - ) - if ( - torch.get_num_threads() == 1 - and get_int_from_env(["OMP_NUM_THREADS", "MKL_NUM_THREADS"], 0) == 0 - and is_psutil_available() - ): - import psutil - - num_cpu_threads_per_process = int(psutil.cpu_count(logical=False) / local_size) - if num_cpu_threads_per_process == 0: - num_cpu_threads_per_process = 1 - torch.set_num_threads(num_cpu_threads_per_process) - logger.info( - f"num_cpu_threads_per_process unset, we set it at {num_cpu_threads_per_process} to improve oob" - " performance." - ) - torch.distributed.init_process_group( - backend=self.xpu_backend, rank=rank, world_size=size, timeout=self.ddp_timeout_delta - ) - elif is_torch_tpu_available(): - device = xm.xla_device() + self.distributed_state = PartialState(cpu=True) + device = self.distributed_state.device self._n_gpu = 0 + self.local_rank = self.distributed_state.local_process_index elif is_sagemaker_mp_enabled(): local_rank = smp.local_rank() device = torch.device("cuda", local_rank) self._n_gpu = 1 - elif is_sagemaker_dp_enabled(): - import smdistributed.dataparallel.torch.torch_smddp # noqa: F401 - - dist.init_process_group(backend="smddp", timeout=self.ddp_timeout_delta) - self.local_rank = int(os.getenv("SMDATAPARALLEL_LOCAL_RANK")) - device = torch.device("cuda", self.local_rank) - self._n_gpu = 1 + torch.cuda.set_device(device) elif self.deepspeed: - # deepspeed inits torch.distributed internally - from .deepspeed import is_deepspeed_available - - if not is_deepspeed_available(): - raise ImportError("--deepspeed requires deepspeed: `pip install deepspeed`.") - import deepspeed - - deepspeed.init_distributed(timeout=timedelta(seconds=self.ddp_timeout)) - - # workaround for setups like notebooks where the launcher can't be used, - # but deepspeed requires a dist env. - # env LOCAL_RANK could be set manually by the user, or via init_distributed if mpi4py is installed - self.local_rank = int(os.environ.get("LOCAL_RANK", "-1")) - - device = torch.device("cuda", self.local_rank) + self.distributed_state = PartialState(timeout=timedelta(seconds=self.ddp_timeout)) self._n_gpu = 1 - elif self.local_rank == -1: + device = self.distributed_state.device + else: + self.distributed_state = PartialState(backend=self.xpu_backend) + device = self.distributed_state.device + self._n_gpu = 1 + if ( + torch.distributed.is_available() + and torch.distributed.is_initialized() + and self.distributed_state.distributed_type != DistributedType.NO + ): + logger.warning( + "torch.distributed process group is initialized, but parallel_mode == ParallelMode.DISTRIBUTED. " + "In order to use Torch DDP, launch your script with `python -m torch.distributed.launch" + ) + + if is_torch_tpu_available(): + device = self.distributed_state.device + self._n_gpu = 0 + elif is_sagemaker_dp_enabled(): + self._n_gpu = 1 + elif self.distributed_state.distributed_type == DistributedType.NO: if self.use_mps_device: if not torch.backends.mps.is_available(): if not torch.backends.mps.is_built(): @@ -1665,24 +1597,13 @@ class TrainingArguments: # trigger an error that a device index is missing. Index 0 takes into account the # GPUs available in the environment, so `CUDA_VISIBLE_DEVICES=1,2` with `cuda:0` # will use the first GPU in that env, i.e. GPU#1 + # device = self.distributed_state.device device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # Sometimes the line in the postinit has not been run before we end up here, so just checking we're not at # the default value. self._n_gpu = torch.cuda.device_count() - else: - # Here, we'll use torch.distributed. - # Initializes the distributed backend which will take care of synchronizing nodes/GPUs - if not torch.distributed.is_initialized(): - if self.xpu_backend and self.xpu_backend in ("mpi", "gloo"): - torch.distributed.init_process_group(backend=self.xpu_backend, timeout=self.ddp_timeout_delta) - else: - torch.distributed.init_process_group(backend="nccl", timeout=self.ddp_timeout_delta) - device = torch.device("cuda", self.local_rank) - self._n_gpu = 1 - - if device.type == "cuda": - torch.cuda.set_device(device) - + if device.type == "cuda": + torch.cuda.set_device(device) return device @property @@ -1725,7 +1646,7 @@ class TrainingArguments: return ParallelMode.SAGEMAKER_MODEL_PARALLEL elif is_sagemaker_dp_enabled(): return ParallelMode.SAGEMAKER_DATA_PARALLEL - elif self.local_rank != -1: + elif hasattr(self, "distributed_state") and (self.distributed_state.distributed_type != DistributedType.NO): return ParallelMode.DISTRIBUTED elif self.n_gpu > 1: return ParallelMode.NOT_DISTRIBUTED @@ -1745,7 +1666,7 @@ class TrainingArguments: return smp.dp_size() if not smp.state.cfg.prescaled_batch else smp.rdp_size() elif is_sagemaker_dp_enabled(): return dist.get_world_size() - elif self.local_rank != -1: + elif self.parallel_mode == ParallelMode.DISTRIBUTED: return torch.distributed.get_world_size() return 1 @@ -1761,7 +1682,7 @@ class TrainingArguments: return smp.dp_rank() if not smp.state.cfg.prescaled_batch else smp.rdp_rank() elif is_sagemaker_dp_enabled(): return dist.get_rank() - elif self.local_rank != -1: + elif self.parallel_mode == ParallelMode.DISTRIBUTED: return torch.distributed.get_rank() return 0 @@ -1777,7 +1698,7 @@ class TrainingArguments: return smp.local_rank() elif is_sagemaker_dp_enabled(): return dist.get_rank() - elif self.local_rank != -1: + elif self.parallel_mode == ParallelMode.DISTRIBUTED: return self.local_rank return 0 diff --git a/tests/trainer/test_trainer_distributed.py b/tests/trainer/test_trainer_distributed.py index 97bca4f9d3..4078885a5d 100644 --- a/tests/trainer/test_trainer_distributed.py +++ b/tests/trainer/test_trainer_distributed.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import sys from typing import Dict from transformers import EvalPrediction, HfArgumentParser, TrainingArguments, is_torch_available @@ -23,6 +22,7 @@ from transformers.testing_utils import ( require_torch_multi_gpu, require_torch_neuroncore, ) +from transformers.training_args import ParallelMode from transformers.utils import logging @@ -66,15 +66,13 @@ if is_torch_available(): class TestTrainerDistributedNeuronCore(TestCasePlus): @require_torch_neuroncore def test_trainer(self): - distributed_args = f""" - -m torch.distributed.run - --nproc_per_node=2 + distributed_args = f"""--nproc_per_node=2 --master_port={get_torch_dist_unique_port()} {self.test_file_dir}/test_trainer_distributed.py """.split() output_dir = self.get_auto_remove_tmp_dir() args = f"--output_dir {output_dir}".split() - cmd = [sys.executable] + distributed_args + args + cmd = ["torchrun"] + distributed_args + args execute_subprocess_async(cmd, env=self.get_env()) # successful return here == success - any errors would have caused an error in the sub-call @@ -82,15 +80,13 @@ class TestTrainerDistributedNeuronCore(TestCasePlus): class TestTrainerDistributed(TestCasePlus): @require_torch_multi_gpu def test_trainer(self): - distributed_args = f""" - -m torch.distributed.run - --nproc_per_node={torch.cuda.device_count()} + distributed_args = f"""--nproc_per_node={torch.cuda.device_count()} --master_port={get_torch_dist_unique_port()} {self.test_file_dir}/test_trainer_distributed.py """.split() output_dir = self.get_auto_remove_tmp_dir() args = f"--output_dir {output_dir}".split() - cmd = [sys.executable] + distributed_args + args + cmd = ["torchrun"] + distributed_args + args execute_subprocess_async(cmd, env=self.get_env()) # successful return here == success - any errors would have caused an error in the sub-call @@ -105,7 +101,7 @@ if __name__ == "__main__": logger.warning( f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, " - f"distributed training: {training_args.local_rank != -1}" + f"distributed training: {training_args.parallel_mode != ParallelMode.NOT_DISTRIBUTED}" ) # Essentially, what we want to verify in the distributed case is that we get all samples back,