Raise err if minimum Accelerate version isn't available (#22841)

* Add warning about accelerate

* Version block Accelerate

* Include parse

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Check partial state

* Update param

---------

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Zachary Mueller
2023-04-18 14:25:02 -04:00
committed by GitHub
parent 5f09219400
commit 5bb4ec6233
2 changed files with 13 additions and 2 deletions

View File

@@ -1531,6 +1531,10 @@ class TrainingArguments:
def _setup_devices(self) -> "torch.device":
requires_backends(self, ["torch"])
logger.info("PyTorch: setting up devices")
if not is_sagemaker_mp_enabled() and not is_accelerate_available(check_partial_state=True):
raise ImportError(
"Using the `Trainer` with `PyTorch` requires `accelerate`: Run `pip install --upgrade accelerate`"
)
if self.no_cuda:
self.distributed_state = PartialState(cpu=True)
device = self.distributed_state.device