Fix parallel mode check (#23409)
* Fix sagemaker/distributed state
* Fix correctly
* Bring back -1
* Bring back local rank for distributed check
* better version
* Cleanest option
This commit is contained in:
@@ -1613,6 +1613,7 @@ class TrainingArguments:
|
|||||||
raise ImportError(
|
raise ImportError(
|
||||||
"Using the `Trainer` with `PyTorch` requires `accelerate>=0.19.0`: Please run `pip install transformers[torch]` or `pip install accelerate -U`"
|
"Using the `Trainer` with `PyTorch` requires `accelerate>=0.19.0`: Please run `pip install transformers[torch]` or `pip install accelerate -U`"
|
||||||
)
|
)
|
||||||
|
self.distributed_state = None
|
||||||
if self.no_cuda:
|
if self.no_cuda:
|
||||||
self.distributed_state = PartialState(cpu=True, backend=self.ddp_backend)
|
self.distributed_state = PartialState(cpu=True, backend=self.ddp_backend)
|
||||||
self._n_gpu = 0
|
self._n_gpu = 0
|
||||||
@@ -1636,7 +1637,7 @@ class TrainingArguments:
|
|||||||
if (
|
if (
|
||||||
torch.distributed.is_available()
|
torch.distributed.is_available()
|
||||||
and torch.distributed.is_initialized()
|
and torch.distributed.is_initialized()
|
||||||
and self.distributed_state.distributed_type == DistributedType.NO
|
and self.parallel_mode != ParallelMode.DISTRIBUTED
|
||||||
):
|
):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"torch.distributed process group is initialized, but parallel_mode != ParallelMode.DISTRIBUTED. "
|
"torch.distributed process group is initialized, but parallel_mode != ParallelMode.DISTRIBUTED. "
|
||||||
@@ -1728,7 +1729,9 @@ class TrainingArguments:
|
|||||||
return ParallelMode.SAGEMAKER_MODEL_PARALLEL
|
return ParallelMode.SAGEMAKER_MODEL_PARALLEL
|
||||||
elif is_sagemaker_dp_enabled():
|
elif is_sagemaker_dp_enabled():
|
||||||
return ParallelMode.SAGEMAKER_DATA_PARALLEL
|
return ParallelMode.SAGEMAKER_DATA_PARALLEL
|
||||||
elif hasattr(self, "distributed_state") and self.distributed_state.distributed_type != DistributedType.NO:
|
elif (
|
||||||
|
self.distributed_state is not None and self.distributed_state.distributed_type != DistributedType.NO
|
||||||
|
) or (self.distributed_state is None and self.local_rank != -1):
|
||||||
return ParallelMode.DISTRIBUTED
|
return ParallelMode.DISTRIBUTED
|
||||||
elif self.n_gpu > 1:
|
elif self.n_gpu > 1:
|
||||||
return ParallelMode.NOT_DISTRIBUTED
|
return ParallelMode.NOT_DISTRIBUTED
|
||||||
|
|||||||
Reference in New Issue
Block a user