Add support for bitsandbytes (#15622)

* Add initial BNB integration

* fixup! Add initial BNB integration

* Add bnb test decorator

* Update Adamw8bit option name

* Use the full bnb package name

* Override bnb for all embedding layers

* Fix package name

* Formatting

* Remove unnecessary import

* Update src/transformers/trainer.py

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Rename AdamwBNB optimizer option

* Add training test checking that bnb memory utilization is lower

* fix merge

* fix merge; fix + extend new test

* cleanup

* expand bnb

* move all require_* candidates to testing_utils.py

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: Stas Bekman <stas@stason.org>
This commit is contained in:
Manuel R. Ciosici
2022-04-19 13:01:29 -07:00
committed by GitHub
parent e6d23a4b9b
commit 3104036e7f
7 changed files with 194 additions and 29 deletions

View File

@@ -65,7 +65,7 @@ from transformers.testing_utils import (
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers.training_args import OptimizerNames
from transformers.utils import WEIGHTS_NAME, is_apex_available
from transformers.utils import WEIGHTS_NAME, is_apex_available, is_bitsandbytes_available
from transformers.utils.hp_naming import TrialShortNamer
@@ -1870,6 +1870,7 @@ if is_torch_available():
},
),
]
if is_apex_available():
import apex
@@ -1881,6 +1882,17 @@ if is_torch_available():
)
)
if is_bitsandbytes_available():
    import bitsandbytes as bnb

    # Only register the 8-bit Adam case when bitsandbytes is importable, since the
    # expected optimizer class has to be taken from the real package.
    optim_test_params.append((OptimizerNames.ADAMW_BNB, bnb.optim.Adam8bit, default_adam_kwargs))
@require_torch
class TrainerOptimizerChoiceTest(unittest.TestCase):
@@ -1905,8 +1917,8 @@ class TrainerOptimizerChoiceTest(unittest.TestCase):
def test_fused_adam(self):
# Pretend that apex is installed and mock apex.optimizers.FusedAdam exists.
# Trainer.get_optimizer_cls_and_kwargs does not use FusedAdam, but only has to return a
# class called, so mocking apex.optimizers.FusedAdam should be fine for testing and allow
# Trainer.get_optimizer_cls_and_kwargs does not use FusedAdam. It only has to return the
# class given, so mocking apex.optimizers.FusedAdam should be fine for testing and allow
# the test to run without requiring an apex installation.
mock = Mock()
modules = {
@@ -1930,6 +1942,33 @@ class TrainerOptimizerChoiceTest(unittest.TestCase):
with self.assertRaises(ValueError):
Trainer.get_optimizer_cls_and_kwargs(args)
def test_bnb_adam8bit(self):
    # Stand in for an installed bitsandbytes package with a Mock, so this test does not
    # need the real library. Trainer.get_optimizer_cls_and_kwargs is only expected to hand
    # back the optimizer class it looked up, so a mocked bitsandbytes.optim.Adam8bit is
    # enough to verify the selection logic.
    fake_bnb = Mock()
    fake_modules = {
        "bitsandbytes": fake_bnb,
        "bitsandbytes.optim": fake_bnb.optim,
        "bitsandbytes.optim.Adam8bit": fake_bnb.optim.Adam8bit,
    }
    # patch.dict restores sys.modules when the block exits, so the fake package does not
    # leak into other tests.
    with patch.dict("sys.modules", fake_modules):
        self.check_optim_and_kwargs(
            OptimizerNames.ADAMW_BNB,
            default_adam_kwargs,
            fake_bnb.optim.Adam8bit,
        )
def test_bnb_adam8bit_no_bnb(self):
    # Requesting the bitsandbytes optimizer when the package cannot be imported must
    # raise a ValueError.
    args = TrainingArguments(optim=OptimizerNames.ADAMW_BNB, output_dir="None")

    # Pretend that bitsandbytes does not exist, even if it is installed. Mapping a module
    # name to None in sys.modules makes any import of that module fail. The key has to be
    # the real module path ("bitsandbytes.optim"): "bnb" is only a local import alias, so
    # patching "bnb.optim" would block nothing and the test would fail whenever
    # bitsandbytes is actually installed.
    with patch.dict("sys.modules", {"bitsandbytes.optim": None}):
        with self.assertRaises(ValueError):
            Trainer.get_optimizer_cls_and_kwargs(args)
@require_torch
@require_wandb