Introduce PartialState as the device handler in the Trainer (#22752)

* Use accelerate for device management

* Add accelerate to setup


Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Zachary Mueller
2023-04-17 15:09:45 -04:00
committed by GitHub
parent 50caa20628
commit 03462875cc
4 changed files with 56 additions and 140 deletions

View File

@@ -260,7 +260,7 @@ extras["sklearn"] = deps_list("scikit-learn")
extras["tf"] = deps_list("tensorflow", "onnxconverter-common", "tf2onnx", "tensorflow-text", "keras-nlp")
extras["tf-cpu"] = deps_list("tensorflow-cpu", "onnxconverter-common", "tf2onnx", "tensorflow-text", "keras-nlp")
extras["torch"] = deps_list("torch")
extras["torch"] = deps_list("torch", "accelerate")
extras["accelerate"] = deps_list("accelerate")
if os.name == "nt": # windows