Add support for bitsandbytes (#15622)
* Add initial BNB integration * fixup! Add initial BNB integration * Add bnb test decorator * Update Adamw8bit option name * Use the full bnb package name * Overide bnb for all embedding layers * Fix package name * Formatting * Remove unnecessary import * Update src/transformers/trainer.py Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> * Rename AdamwBNB optimizer option * Add training test checking that bnb memory utilization is lower * fix merge * fix merge; fix + extend new test * cleanup * expand bnb * move all require_* candidates to testing_utils.py Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> Co-authored-by: Stas Bekman <stas@stason.org>
This commit is contained in:
committed by
GitHub
parent
e6d23a4b9b
commit
3104036e7f
@@ -31,8 +31,16 @@ from unittest import mock
|
||||
from transformers import logging as transformers_logging
|
||||
|
||||
from .deepspeed import is_deepspeed_available
|
||||
from .integrations import is_optuna_available, is_ray_available, is_sigopt_available, is_wandb_available
|
||||
from .integrations import (
|
||||
is_fairscale_available,
|
||||
is_optuna_available,
|
||||
is_ray_available,
|
||||
is_sigopt_available,
|
||||
is_wandb_available,
|
||||
)
|
||||
from .utils import (
|
||||
is_apex_available,
|
||||
is_bitsandbytes_available,
|
||||
is_detectron2_available,
|
||||
is_faiss_available,
|
||||
is_flax_available,
|
||||
@@ -638,6 +646,36 @@ def require_deepspeed(test_case):
|
||||
return test_case
|
||||
|
||||
|
||||
def require_fairscale(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires fairscale
|
||||
"""
|
||||
if not is_fairscale_available():
|
||||
return unittest.skip("test requires fairscale")(test_case)
|
||||
else:
|
||||
return test_case
|
||||
|
||||
|
||||
def require_apex(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires apex
|
||||
"""
|
||||
if not is_apex_available():
|
||||
return unittest.skip("test requires apex")(test_case)
|
||||
else:
|
||||
return test_case
|
||||
|
||||
|
||||
def require_bitsandbytes(test_case):
|
||||
"""
|
||||
Decorator for bits and bytes (bnb) dependency
|
||||
"""
|
||||
if not is_bitsandbytes_available():
|
||||
return unittest.skip("test requires bnb")(test_case)
|
||||
else:
|
||||
return test_case
|
||||
|
||||
|
||||
def require_phonemizer(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires phonemizer
|
||||
|
||||
Reference in New Issue
Block a user