Introduce PartialState as the device handler in the Trainer (#22752)

* Use accelerate for device management

* Add accelerate to setup


Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Zachary Mueller
2023-04-17 15:09:45 -04:00
committed by GitHub
parent 50caa20628
commit 03462875cc
4 changed files with 56 additions and 140 deletions

View File

@@ -38,10 +38,8 @@ from .trainer_utils import (
from .utils import (
ExplicitEnum,
cached_property,
ccl_version,
get_full_repo_name,
is_accelerate_available,
is_psutil_available,
is_safetensors_available,
is_sagemaker_dp_enabled,
is_sagemaker_mp_enabled,
@@ -65,6 +63,10 @@ if is_torch_available():
import torch
import torch.distributed as dist
if is_accelerate_available():
from accelerate import PartialState
from accelerate.utils import DistributedType
if is_torch_tpu_available(check_device=False):
import torch_xla.core.xla_model as xm
@@ -1122,12 +1124,6 @@ class TrainingArguments:
)
def __post_init__(self):
# Handle --use_env option in torch.distributed.launch (local_rank not passed as an arg then).
# This needs to happen before any call to self.device or self.n_gpu.
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != self.local_rank:
self.local_rank = env_local_rank
# Expand user paths (e.g. "~") first; otherwise os.makedirs("~/bar") will create the directory
# in the current working directory instead of the actual home directory.
# see https://github.com/huggingface/transformers/issues/10628
@@ -1535,104 +1531,40 @@ class TrainingArguments:
def _setup_devices(self) -> "torch.device":
requires_backends(self, ["torch"])
logger.info("PyTorch: setting up devices")
if torch.distributed.is_available() and torch.distributed.is_initialized() and self.local_rank == -1:
logger.warning(
"torch.distributed process group is initialized, but local_rank == -1. "
"In order to use Torch DDP, launch your script with `python -m torch.distributed.launch"
)
if self.no_cuda:
device = torch.device("cpu")
self._n_gpu = 0
self.local_rank = get_int_from_env(
["LOCAL_RANK", "MPI_LOCALRANKID", "OMPI_COMM_WORLD_LOCAL_RANK", "MV2_COMM_WORLD_LOCAL_RANK"],
self.local_rank,
)
if self.local_rank != -1 and not torch.distributed.is_initialized():
# Initializes distributed backend for cpu
if self.xpu_backend not in ("mpi", "ccl", "gloo"):
raise ValueError(
"CPU distributed training backend is not properly set. "
"Please set '--xpu_backend' to either 'mpi' or 'ccl' or 'gloo'."
)
if self.xpu_backend == "ccl":
requires_backends(self, "oneccl_bind_pt")
if ccl_version >= "1.12":
import oneccl_bindings_for_pytorch # noqa: F401
else:
import torch_ccl # noqa: F401
if int(os.environ.get("CCL_WORKER_COUNT", 0)) < 1:
raise ValueError(
"CPU distributed training backend is ccl. but CCL_WORKER_COUNT is not correctly set. "
"Please use like 'export CCL_WORKER_COUNT = 1' to set."
)
# Try to get launch configuration from environment variables set by MPI launcher - works for Intel MPI, OpenMPI and MVAPICH
rank = get_int_from_env(["RANK", "PMI_RANK", "OMPI_COMM_WORLD_RANK", "MV2_COMM_WORLD_RANK"], 0)
size = get_int_from_env(["WORLD_SIZE", "PMI_SIZE", "OMPI_COMM_WORLD_SIZE", "MV2_COMM_WORLD_SIZE"], 1)
local_size = get_int_from_env(
["MPI_LOCALNRANKS", "OMPI_COMM_WORLD_LOCAL_SIZE", "MV2_COMM_WORLD_LOCAL_SIZE"], 1
)
os.environ["RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(size)
os.environ["LOCAL_RANK"] = str(self.local_rank)
if not os.environ.get("MASTER_PORT", None):
os.environ["MASTER_PORT"] = "29500"
if not os.environ.get("MASTER_ADDR", None):
if local_size != size or self.xpu_backend != "mpi":
raise ValueError(
"Looks like distributed multinode run but MASTER_ADDR env not set, "
"please try exporting rank 0's hostname as MASTER_ADDR"
)
if (
torch.get_num_threads() == 1
and get_int_from_env(["OMP_NUM_THREADS", "MKL_NUM_THREADS"], 0) == 0
and is_psutil_available()
):
import psutil
num_cpu_threads_per_process = int(psutil.cpu_count(logical=False) / local_size)
if num_cpu_threads_per_process == 0:
num_cpu_threads_per_process = 1
torch.set_num_threads(num_cpu_threads_per_process)
logger.info(
f"num_cpu_threads_per_process unset, we set it at {num_cpu_threads_per_process} to improve oob"
" performance."
)
torch.distributed.init_process_group(
backend=self.xpu_backend, rank=rank, world_size=size, timeout=self.ddp_timeout_delta
)
elif is_torch_tpu_available():
device = xm.xla_device()
self.distributed_state = PartialState(cpu=True)
device = self.distributed_state.device
self._n_gpu = 0
self.local_rank = self.distributed_state.local_process_index
elif is_sagemaker_mp_enabled():
local_rank = smp.local_rank()
device = torch.device("cuda", local_rank)
self._n_gpu = 1
elif is_sagemaker_dp_enabled():
import smdistributed.dataparallel.torch.torch_smddp # noqa: F401
dist.init_process_group(backend="smddp", timeout=self.ddp_timeout_delta)
self.local_rank = int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))
device = torch.device("cuda", self.local_rank)
self._n_gpu = 1
torch.cuda.set_device(device)
elif self.deepspeed:
# deepspeed inits torch.distributed internally
from .deepspeed import is_deepspeed_available
if not is_deepspeed_available():
raise ImportError("--deepspeed requires deepspeed: `pip install deepspeed`.")
import deepspeed
deepspeed.init_distributed(timeout=timedelta(seconds=self.ddp_timeout))
# workaround for setups like notebooks where the launcher can't be used,
# but deepspeed requires a dist env.
# env LOCAL_RANK could be set manually by the user, or via init_distributed if mpi4py is installed
self.local_rank = int(os.environ.get("LOCAL_RANK", "-1"))
device = torch.device("cuda", self.local_rank)
self.distributed_state = PartialState(timeout=timedelta(seconds=self.ddp_timeout))
self._n_gpu = 1
elif self.local_rank == -1:
device = self.distributed_state.device
else:
self.distributed_state = PartialState(backend=self.xpu_backend)
device = self.distributed_state.device
self._n_gpu = 1
if (
torch.distributed.is_available()
and torch.distributed.is_initialized()
and self.distributed_state.distributed_type != DistributedType.NO
):
logger.warning(
"torch.distributed process group is initialized, but parallel_mode == ParallelMode.DISTRIBUTED. "
"In order to use Torch DDP, launch your script with `python -m torch.distributed.launch"
)
if is_torch_tpu_available():
device = self.distributed_state.device
self._n_gpu = 0
elif is_sagemaker_dp_enabled():
self._n_gpu = 1
elif self.distributed_state.distributed_type == DistributedType.NO:
if self.use_mps_device:
if not torch.backends.mps.is_available():
if not torch.backends.mps.is_built():
@@ -1665,24 +1597,13 @@ class TrainingArguments:
# trigger an error that a device index is missing. Index 0 takes into account the
# GPUs available in the environment, so `CUDA_VISIBLE_DEVICES=1,2` with `cuda:0`
# will use the first GPU in that env, i.e. GPU#1
# device = self.distributed_state.device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Sometimes the line in the postinit has not been run before we end up here, so just checking we're not at
# the default value.
self._n_gpu = torch.cuda.device_count()
else:
# Here, we'll use torch.distributed.
# Initializes the distributed backend which will take care of synchronizing nodes/GPUs
if not torch.distributed.is_initialized():
if self.xpu_backend and self.xpu_backend in ("mpi", "gloo"):
torch.distributed.init_process_group(backend=self.xpu_backend, timeout=self.ddp_timeout_delta)
else:
torch.distributed.init_process_group(backend="nccl", timeout=self.ddp_timeout_delta)
device = torch.device("cuda", self.local_rank)
self._n_gpu = 1
if device.type == "cuda":
torch.cuda.set_device(device)
if device.type == "cuda":
torch.cuda.set_device(device)
return device
@property
@@ -1725,7 +1646,7 @@ class TrainingArguments:
return ParallelMode.SAGEMAKER_MODEL_PARALLEL
elif is_sagemaker_dp_enabled():
return ParallelMode.SAGEMAKER_DATA_PARALLEL
elif self.local_rank != -1:
elif hasattr(self, "distributed_state") and (self.distributed_state.distributed_type != DistributedType.NO):
return ParallelMode.DISTRIBUTED
elif self.n_gpu > 1:
return ParallelMode.NOT_DISTRIBUTED
@@ -1745,7 +1666,7 @@ class TrainingArguments:
return smp.dp_size() if not smp.state.cfg.prescaled_batch else smp.rdp_size()
elif is_sagemaker_dp_enabled():
return dist.get_world_size()
elif self.local_rank != -1:
elif self.parallel_mode == ParallelMode.DISTRIBUTED:
return torch.distributed.get_world_size()
return 1
@@ -1761,7 +1682,7 @@ class TrainingArguments:
return smp.dp_rank() if not smp.state.cfg.prescaled_batch else smp.rdp_rank()
elif is_sagemaker_dp_enabled():
return dist.get_rank()
elif self.local_rank != -1:
elif self.parallel_mode == ParallelMode.DISTRIBUTED:
return torch.distributed.get_rank()
return 0
@@ -1777,7 +1698,7 @@ class TrainingArguments:
return smp.local_rank()
elif is_sagemaker_dp_enabled():
return dist.get_rank()
elif self.local_rank != -1:
elif self.parallel_mode == ParallelMode.DISTRIBUTED:
return self.local_rank
return 0