Migrate HFDeepSpeedConfig from transformers to accelerate (#17623)
* Migrate HFDeepSpeedConfig from transformers to accelerate * add `accelerate` to testing dep * addressing comments * addressing comments Using `_shared_state` and avoiding object creation. This is necessary as `notebook_launcher` in `launchers.py` checks `len(AcceleratorState._shared_state)>0` to throw an error. * resolving comments 1. Use simple API from accelerate to manage the deepspeed config integration 2. Update the related documentation * reverting changes and addressing comments * docstring correction * addressing nits * addressing nits * addressing nits 3 * bumping up the accelerate version to 0.10.0 * resolving import * update setup.py to include deepspeed dependencies * Update dependency_versions_table.py * fixing imports * reverting changes to CI dependencies for "run_tests_pipelines_tf*" tests These changes didn't help with resolving the failures and I believe this needs to be addressed in another PR. * removing `accelerate` as hard dependency Resolves issues related to CI Tests * adding `accelerate` as dependency for building docs resolves failure in Build PR Documentation test * adding `accelerate` as dependency in "dev" to resolve doc build issue * resolving comments 1. adding `accelerate` to extras["all"] 2. Including check for accelerate too before importing HFDeepSpeedConfig from there Co-Authored-By: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * resolving comments Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
e44a569fef
commit
21a772426d
36
setup.py
36
setup.py
@@ -97,7 +97,7 @@ if stale_egg_info.exists():
|
||||
# 2. once modified, run: `make deps_table_update` to update src/transformers/dependency_versions_table.py
|
||||
_deps = [
|
||||
"Pillow",
|
||||
"accelerate>=0.9.0",
|
||||
"accelerate>=0.10.0",
|
||||
"black~=22.0,>=22.3",
|
||||
"codecarbon==1.2.0",
|
||||
"cookiecutter==1.7.3",
|
||||
@@ -242,6 +242,7 @@ extras["tf"] = deps_list("tensorflow", "onnxconverter-common", "tf2onnx")
|
||||
extras["tf-cpu"] = deps_list("tensorflow-cpu", "onnxconverter-common", "tf2onnx")
|
||||
|
||||
extras["torch"] = deps_list("torch")
|
||||
extras["accelerate"] = deps_list("accelerate")
|
||||
|
||||
if os.name == "nt": # windows
|
||||
extras["retrieval"] = deps_list("datasets") # faiss is not supported on windows
|
||||
@@ -257,7 +258,7 @@ extras["onnx"] = deps_list("onnxconverter-common", "tf2onnx") + extras["onnxrunt
|
||||
extras["modelcreation"] = deps_list("cookiecutter")
|
||||
|
||||
extras["sagemaker"] = deps_list("sagemaker")
|
||||
extras["deepspeed"] = deps_list("deepspeed")
|
||||
extras["deepspeed"] = deps_list("deepspeed") + extras["accelerate"]
|
||||
extras["fairscale"] = deps_list("fairscale")
|
||||
extras["optuna"] = deps_list("optuna")
|
||||
extras["ray"] = deps_list("ray[tune]")
|
||||
@@ -293,9 +294,9 @@ extras["testing"] = (
|
||||
"nltk",
|
||||
"GitPython",
|
||||
"hf-doc-builder",
|
||||
"protobuf", # Can be removed once we can unpin protobuf
|
||||
"protobuf", # Can be removed once we can unpin protobuf
|
||||
"sacremoses",
|
||||
"rjieba"
|
||||
"rjieba",
|
||||
)
|
||||
+ extras["retrieval"]
|
||||
+ extras["modelcreation"]
|
||||
@@ -316,6 +317,7 @@ extras["all"] = (
|
||||
+ extras["integrations"]
|
||||
+ extras["timm"]
|
||||
+ extras["codecarbon"]
|
||||
+ extras["accelerate"]
|
||||
)
|
||||
|
||||
# Might need to add doc-builder and some specific deps in the future
|
||||
@@ -325,8 +327,8 @@ extras["docs_specific"] = ["hf-doc-builder"]
|
||||
extras["docs"] = extras["all"] + extras["docs_specific"]
|
||||
|
||||
extras["dev-torch"] = (
|
||||
extras['testing']
|
||||
+ extras['torch']
|
||||
extras["testing"]
|
||||
+ extras["torch"]
|
||||
+ extras["sentencepiece"]
|
||||
+ extras["tokenizers"]
|
||||
+ extras["torch-speech"]
|
||||
@@ -342,17 +344,17 @@ extras["dev-torch"] = (
|
||||
+ extras["onnxruntime"]
|
||||
)
|
||||
extras["dev-tensorflow"] = (
|
||||
extras['testing']
|
||||
+ extras['tf']
|
||||
+ extras["sentencepiece"]
|
||||
+ extras["tokenizers"]
|
||||
+ extras["vision"]
|
||||
+ extras["quality"]
|
||||
+ extras["docs_specific"]
|
||||
+ extras["sklearn"]
|
||||
+ extras["modelcreation"]
|
||||
+ extras["onnx"]
|
||||
+ extras["tf-speech"]
|
||||
extras["testing"]
|
||||
+ extras["tf"]
|
||||
+ extras["sentencepiece"]
|
||||
+ extras["tokenizers"]
|
||||
+ extras["vision"]
|
||||
+ extras["quality"]
|
||||
+ extras["docs_specific"]
|
||||
+ extras["sklearn"]
|
||||
+ extras["modelcreation"]
|
||||
+ extras["onnx"]
|
||||
+ extras["tf-speech"]
|
||||
)
|
||||
extras["dev"] = (
|
||||
extras["all"]
|
||||
|
||||
Reference in New Issue
Block a user