Fix sagemaker DP/MP (#23681)
* Check for use_sagemaker_dp * Add a check for is_sagemaker_mp when setting _n_gpu again. Should be last broken thing * Try explicit check? * Quality
This commit is contained in:
@@ -1629,6 +1629,9 @@ class TrainingArguments:
|
||||
device = torch.device("cuda", local_rank)
|
||||
self._n_gpu = 1
|
||||
torch.cuda.set_device(device)
|
||||
elif is_sagemaker_dp_enabled():
|
||||
self.distributed_state = PartialState(_use_sagemaker_dp=True)
|
||||
self._n_gpu = 1
|
||||
elif self.deepspeed:
|
||||
# Need to do similar for Accelerator init
|
||||
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
|
||||
@@ -1653,8 +1656,9 @@ class TrainingArguments:
|
||||
if is_torch_tpu_available():
|
||||
device = self.distributed_state.device
|
||||
self._n_gpu = 0
|
||||
elif is_sagemaker_dp_enabled():
|
||||
self._n_gpu = 1
|
||||
elif is_sagemaker_dp_enabled() or is_sagemaker_mp_enabled():
|
||||
# Already set _n_gpu
|
||||
pass
|
||||
elif self.distributed_state.distributed_type == DistributedType.NO:
|
||||
if self.use_mps_device:
|
||||
if not torch.backends.mps.is_available():
|
||||
|
||||
Reference in New Issue
Block a user