diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index d213b08d6d..441a546e05 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -1667,12 +1667,12 @@ class TrainingArguments: def _setup_devices(self) -> "torch.device": requires_backends(self, ["torch"]) logger.info("PyTorch: setting up devices") - AcceleratorState._reset_state() - PartialState._reset_state() - if not is_sagemaker_mp_enabled() and not is_accelerate_available(check_partial_state=True): - raise ImportError( - "Using the `Trainer` with `PyTorch` requires `accelerate>=0.19.0`: Please run `pip install transformers[torch]` or `pip install accelerate -U`" - ) + if not is_sagemaker_mp_enabled(): + if not is_accelerate_available(check_partial_state=True): + raise ImportError( + "Using the `Trainer` with `PyTorch` requires `accelerate>=0.19.0`: Please run `pip install transformers[torch]` or `pip install accelerate -U`" + ) + AcceleratorState._reset_state(reset_partial_state=True) self.distributed_state = None if self.no_cuda: self.distributed_state = PartialState(cpu=True, backend=self.ddp_backend)