[setup] make fairscale and deepspeed setup extras (#11151)
* make fairscale and deepspeed setup extras * fix default * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * no reason not to ask for the good version * update the CIs Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
7
.github/workflows/self-scheduled.yml
vendored
7
.github/workflows/self-scheduled.yml
vendored
@@ -33,8 +33,7 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
apt -y update && apt install -y libsndfile1-dev
|
apt -y update && apt install -y libsndfile1-dev
|
||||||
pip install --upgrade pip
|
pip install --upgrade pip
|
||||||
pip install .[sklearn,testing,onnxruntime,sentencepiece,speech]
|
pip install .[sklearn,testing,onnxruntime,sentencepiece,speech,deepspeed]
|
||||||
pip install deepspeed
|
|
||||||
|
|
||||||
- name: Are GPUs recognized by our DL frameworks
|
- name: Are GPUs recognized by our DL frameworks
|
||||||
run: |
|
run: |
|
||||||
@@ -156,9 +155,7 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
apt -y update && apt install -y libsndfile1-dev
|
apt -y update && apt install -y libsndfile1-dev
|
||||||
pip install --upgrade pip
|
pip install --upgrade pip
|
||||||
pip install .[sklearn,testing,onnxruntime,sentencepiece,speech]
|
pip install .[sklearn,testing,onnxruntime,sentencepiece,speech,deepspeed,fairscale]
|
||||||
pip install fairscale
|
|
||||||
pip install deepspeed
|
|
||||||
|
|
||||||
- name: Are GPUs recognized by our DL frameworks
|
- name: Are GPUs recognized by our DL frameworks
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
@@ -274,6 +274,14 @@ Install the library via pypi:
|
|||||||
|
|
||||||
pip install fairscale
|
pip install fairscale
|
||||||
|
|
||||||
|
or via ``transformers``' ``extras``:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
pip install transformers[fairscale]
|
||||||
|
|
||||||
|
(will become available starting from ``transformers==4.6.0``)
|
||||||
|
|
||||||
or find more details on `the FairScale's GitHub page <https://github.com/facebookresearch/fairscale/#installation>`__.
|
or find more details on `the FairScale's GitHub page <https://github.com/facebookresearch/fairscale/#installation>`__.
|
||||||
|
|
||||||
If you're still struggling with the build, first make sure to read :ref:`zero-install-notes`.
|
If you're still struggling with the build, first make sure to read :ref:`zero-install-notes`.
|
||||||
@@ -419,6 +427,14 @@ Install the library via pypi:
|
|||||||
|
|
||||||
pip install deepspeed
|
pip install deepspeed
|
||||||
|
|
||||||
|
or via ``transformers``' ``extras``:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
pip install transformers[deepspeed]
|
||||||
|
|
||||||
|
(will become available starting from ``transformers==4.6.0``)
|
||||||
|
|
||||||
or find more details on `the DeepSpeed's GitHub page <https://github.com/microsoft/deepspeed#installation>`__ and
|
or find more details on `the DeepSpeed's GitHub page <https://github.com/microsoft/deepspeed#installation>`__ and
|
||||||
`advanced install <https://www.deepspeed.ai/tutorials/advanced-install/>`__.
|
`advanced install <https://www.deepspeed.ai/tutorials/advanced-install/>`__.
|
||||||
|
|
||||||
|
|||||||
4
setup.py
4
setup.py
@@ -90,7 +90,9 @@ _deps = [
|
|||||||
"cookiecutter==1.7.2",
|
"cookiecutter==1.7.2",
|
||||||
"dataclasses",
|
"dataclasses",
|
||||||
"datasets",
|
"datasets",
|
||||||
|
"deepspeed>0.3.13",
|
||||||
"docutils==0.16.0",
|
"docutils==0.16.0",
|
||||||
|
"fairscale>0.3",
|
||||||
"faiss-cpu",
|
"faiss-cpu",
|
||||||
"fastapi",
|
"fastapi",
|
||||||
"filelock",
|
"filelock",
|
||||||
@@ -233,6 +235,8 @@ extras["onnx"] = deps_list("onnxconverter-common", "keras2onnx") + extras["onnxr
|
|||||||
extras["modelcreation"] = deps_list("cookiecutter")
|
extras["modelcreation"] = deps_list("cookiecutter")
|
||||||
|
|
||||||
extras["sagemaker"] = deps_list("sagemaker")
|
extras["sagemaker"] = deps_list("sagemaker")
|
||||||
|
extras["deepspeed"] = deps_list("deepspeed")
|
||||||
|
extras["fairscale"] = deps_list("fairscale")
|
||||||
|
|
||||||
extras["serving"] = deps_list("pydantic", "uvicorn", "fastapi", "starlette")
|
extras["serving"] = deps_list("pydantic", "uvicorn", "fastapi", "starlette")
|
||||||
extras["speech"] = deps_list("soundfile", "torchaudio")
|
extras["speech"] = deps_list("soundfile", "torchaudio")
|
||||||
|
|||||||
@@ -14,7 +14,7 @@
|
|||||||
import sys
|
import sys
|
||||||
|
|
||||||
from .dependency_versions_table import deps
|
from .dependency_versions_table import deps
|
||||||
from .utils.versions import require_version_core
|
from .utils.versions import require_version, require_version_core
|
||||||
|
|
||||||
|
|
||||||
# define which module versions we always want to check at run time
|
# define which module versions we always want to check at run time
|
||||||
@@ -41,3 +41,7 @@ for pkg in pkgs_to_check_at_runtime:
|
|||||||
require_version_core(deps[pkg])
|
require_version_core(deps[pkg])
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")
|
raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")
|
||||||
|
|
||||||
|
|
||||||
|
def dep_version_check(pkg, hint=None):
|
||||||
|
require_version(deps[pkg], hint)
|
||||||
|
|||||||
@@ -7,7 +7,9 @@ deps = {
|
|||||||
"cookiecutter": "cookiecutter==1.7.2",
|
"cookiecutter": "cookiecutter==1.7.2",
|
||||||
"dataclasses": "dataclasses",
|
"dataclasses": "dataclasses",
|
||||||
"datasets": "datasets",
|
"datasets": "datasets",
|
||||||
|
"deepspeed": "deepspeed>0.3.13",
|
||||||
"docutils": "docutils==0.16.0",
|
"docutils": "docutils==0.16.0",
|
||||||
|
"fairscale": "fairscale>0.3",
|
||||||
"faiss-cpu": "faiss-cpu",
|
"faiss-cpu": "faiss-cpu",
|
||||||
"fastapi": "fastapi",
|
"fastapi": "fastapi",
|
||||||
"filelock": "filelock",
|
"filelock": "filelock",
|
||||||
|
|||||||
@@ -24,8 +24,8 @@ import tempfile
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
from .dependency_versions_check import dep_version_check
|
||||||
from .utils import logging
|
from .utils import logging
|
||||||
from .utils.versions import require_version
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
@@ -324,7 +324,7 @@ def deepspeed_parse_config(ds_config):
|
|||||||
|
|
||||||
If it's already a dict, return a copy of it, so that we can freely modify it.
|
If it's already a dict, return a copy of it, so that we can freely modify it.
|
||||||
"""
|
"""
|
||||||
require_version("deepspeed>0.3.13")
|
dep_version_check("deepspeed")
|
||||||
|
|
||||||
if isinstance(ds_config, dict):
|
if isinstance(ds_config, dict):
|
||||||
# Don't modify user's data should they want to reuse it (e.g. in tests), because once we
|
# Don't modify user's data should they want to reuse it (e.g. in tests), because once we
|
||||||
|
|||||||
@@ -54,6 +54,7 @@ from torch.utils.data.distributed import DistributedSampler
|
|||||||
from torch.utils.data.sampler import RandomSampler, SequentialSampler
|
from torch.utils.data.sampler import RandomSampler, SequentialSampler
|
||||||
|
|
||||||
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
|
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
|
||||||
|
from .dependency_versions_check import dep_version_check
|
||||||
from .file_utils import (
|
from .file_utils import (
|
||||||
WEIGHTS_NAME,
|
WEIGHTS_NAME,
|
||||||
is_apex_available,
|
is_apex_available,
|
||||||
@@ -139,17 +140,14 @@ if is_torch_tpu_available():
|
|||||||
import torch_xla.distributed.parallel_loader as pl
|
import torch_xla.distributed.parallel_loader as pl
|
||||||
|
|
||||||
if is_fairscale_available():
|
if is_fairscale_available():
|
||||||
|
dep_version_check("fairscale")
|
||||||
import fairscale
|
import fairscale
|
||||||
|
from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP
|
||||||
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
|
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
|
||||||
|
from fairscale.nn.wrap import auto_wrap
|
||||||
from fairscale.optim import OSS
|
from fairscale.optim import OSS
|
||||||
from fairscale.optim.grad_scaler import ShardedGradScaler
|
from fairscale.optim.grad_scaler import ShardedGradScaler
|
||||||
|
|
||||||
if version.parse(fairscale.__version__) >= version.parse("0.3"):
|
|
||||||
from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP
|
|
||||||
from fairscale.nn.wrap import auto_wrap
|
|
||||||
else:
|
|
||||||
FullyShardedDDP = None
|
|
||||||
|
|
||||||
if is_sagemaker_dp_enabled():
|
if is_sagemaker_dp_enabled():
|
||||||
import smdistributed.dataparallel.torch.distributed as dist
|
import smdistributed.dataparallel.torch.distributed as dist
|
||||||
from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel as DDP
|
from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel as DDP
|
||||||
|
|||||||
@@ -60,6 +60,12 @@ def require_version(requirement: str, hint: Optional[str] = None) -> None:
|
|||||||
Args:
|
Args:
|
||||||
requirement (:obj:`str`): pip style definition, e.g., "tokenizers==0.9.4", "tqdm>=4.27", "numpy"
|
requirement (:obj:`str`): pip style definition, e.g., "tokenizers==0.9.4", "tqdm>=4.27", "numpy"
|
||||||
hint (:obj:`str`, `optional`): what suggestion to print in case of requirements not being met
|
hint (:obj:`str`, `optional`): what suggestion to print in case of requirements not being met
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
require_version("pandas>1.1.2")
|
||||||
|
require_version("numpy>1.18.5", "this is important to have for whatever reason")
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
hint = f"\n{hint}" if hint is not None else ""
|
hint = f"\n{hint}" if hint is not None else ""
|
||||||
|
|||||||
Reference in New Issue
Block a user