Add support for bitsandbytes (#15622)

* Add initial BNB integration

* fixup! Add initial BNB integration

* Add bnb test decorator

* Update Adamw8bit option name

* Use the full bnb package name

* Overide bnb for all embedding layers

* Fix package name

* Formatting

* Remove unnecessary import

* Update src/transformers/trainer.py

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Rename AdamwBNB optimizer option

* Add training test checking that bnb memory utilization is lower

* fix merge

* fix merge; fix + extend new test

* cleanup

* expand bnb

* move all require_* candidates to testing_utils.py

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: Stas Bekman <stas@stason.org>
This commit is contained in:
Manuel R. Ciosici
2022-04-19 13:01:29 -07:00
committed by GitHub
parent e6d23a4b9b
commit 3104036e7f
7 changed files with 194 additions and 29 deletions

View File

@@ -31,8 +31,16 @@ from unittest import mock
from transformers import logging as transformers_logging
from .deepspeed import is_deepspeed_available
from .integrations import is_optuna_available, is_ray_available, is_sigopt_available, is_wandb_available
from .integrations import (
is_fairscale_available,
is_optuna_available,
is_ray_available,
is_sigopt_available,
is_wandb_available,
)
from .utils import (
is_apex_available,
is_bitsandbytes_available,
is_detectron2_available,
is_faiss_available,
is_flax_available,
@@ -638,6 +646,36 @@ def require_deepspeed(test_case):
return test_case
def require_fairscale(test_case):
"""
Decorator marking a test that requires fairscale
"""
if not is_fairscale_available():
return unittest.skip("test requires fairscale")(test_case)
else:
return test_case
def require_apex(test_case):
"""
Decorator marking a test that requires apex
"""
if not is_apex_available():
return unittest.skip("test requires apex")(test_case)
else:
return test_case
def require_bitsandbytes(test_case):
"""
Decorator for bits and bytes (bnb) dependency
"""
if not is_bitsandbytes_available():
return unittest.skip("test requires bnb")(test_case)
else:
return test_case
def require_phonemizer(test_case):
"""
Decorator marking a test that requires phonemizer