Fix sagemaker DP/MP (#23681)
* Check for use_sagemaker_dp * Add a check for is_sagemaker_mp when setting _n_gpu again. Should be last broken thing * Try explicit check? * Quality
This commit is contained in:
@@ -3398,7 +3398,9 @@ class Trainer:
|
|||||||
tensors = nested_xla_mesh_reduce(tensors, name)
|
tensors = nested_xla_mesh_reduce(tensors, name)
|
||||||
elif is_sagemaker_mp_enabled():
|
elif is_sagemaker_mp_enabled():
|
||||||
tensors = smp_gather(tensors)
|
tensors = smp_gather(tensors)
|
||||||
elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
|
elif (self.args.distributed_state is not None and self.args.distributed_state.distributed_type != "NO") or (
|
||||||
|
self.args.distributed_state is None and self.local_rank != -1
|
||||||
|
):
|
||||||
tensors = distributed_concat(tensors)
|
tensors = distributed_concat(tensors)
|
||||||
return tensors
|
return tensors
|
||||||
|
|
||||||
|
|||||||
@@ -1629,6 +1629,9 @@ class TrainingArguments:
|
|||||||
device = torch.device("cuda", local_rank)
|
device = torch.device("cuda", local_rank)
|
||||||
self._n_gpu = 1
|
self._n_gpu = 1
|
||||||
torch.cuda.set_device(device)
|
torch.cuda.set_device(device)
|
||||||
|
elif is_sagemaker_dp_enabled():
|
||||||
|
self.distributed_state = PartialState(_use_sagemaker_dp=True)
|
||||||
|
self._n_gpu = 1
|
||||||
elif self.deepspeed:
|
elif self.deepspeed:
|
||||||
# Need to do similar for Accelerator init
|
# Need to do similar for Accelerator init
|
||||||
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
|
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
|
||||||
@@ -1653,8 +1656,9 @@ class TrainingArguments:
|
|||||||
if is_torch_tpu_available():
|
if is_torch_tpu_available():
|
||||||
device = self.distributed_state.device
|
device = self.distributed_state.device
|
||||||
self._n_gpu = 0
|
self._n_gpu = 0
|
||||||
elif is_sagemaker_dp_enabled():
|
elif is_sagemaker_dp_enabled() or is_sagemaker_mp_enabled():
|
||||||
self._n_gpu = 1
|
# Already set _n_gpu
|
||||||
|
pass
|
||||||
elif self.distributed_state.distributed_type == DistributedType.NO:
|
elif self.distributed_state.distributed_type == DistributedType.NO:
|
||||||
if self.use_mps_device:
|
if self.use_mps_device:
|
||||||
if not torch.backends.mps.is_available():
|
if not torch.backends.mps.is_available():
|
||||||
|
|||||||
Reference in New Issue
Block a user