Add support for bitsandbytes (#15622)
* Add initial BNB integration * fixup! Add initial BNB integration * Add bnb test decorator * Update Adamw8bit option name * Use the full bnb package name * Override bnb for all embedding layers * Fix package name * Formatting * Remove unnecessary import * Update src/transformers/trainer.py Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> * Rename AdamwBNB optimizer option * Add training test checking that bnb memory utilization is lower * fix merge * fix merge; fix + extend new test * cleanup * expand bnb * move all require_* candidates to testing_utils.py Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> Co-authored-by: Stas Bekman <stas@stason.org>
This commit is contained in:
committed by
GitHub
parent
e6d23a4b9b
commit
3104036e7f
@@ -65,7 +65,7 @@ from transformers.testing_utils import (
|
||||
)
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
||||
from transformers.training_args import OptimizerNames
|
||||
from transformers.utils import WEIGHTS_NAME, is_apex_available
|
||||
from transformers.utils import WEIGHTS_NAME, is_apex_available, is_bitsandbytes_available
|
||||
from transformers.utils.hp_naming import TrialShortNamer
|
||||
|
||||
|
||||
@@ -1870,6 +1870,7 @@ if is_torch_available():
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
if is_apex_available():
|
||||
import apex
|
||||
|
||||
@@ -1881,6 +1882,17 @@ if is_torch_available():
|
||||
)
|
||||
)
|
||||
|
||||
if is_bitsandbytes_available():
    import bitsandbytes as bnb

    # Register the bitsandbytes 8-bit Adam case alongside the other optimizer
    # test parameters: (optimizer name, expected class, expected kwargs).
    bnb_adam8bit_case = (OptimizerNames.ADAMW_BNB, bnb.optim.Adam8bit, default_adam_kwargs)
    optim_test_params.append(bnb_adam8bit_case)
|
||||
|
||||
|
||||
@require_torch
|
||||
class TrainerOptimizerChoiceTest(unittest.TestCase):
|
||||
@@ -1905,8 +1917,8 @@ class TrainerOptimizerChoiceTest(unittest.TestCase):
|
||||
|
||||
def test_fused_adam(self):
|
||||
# Pretend that apex is installed and mock apex.optimizers.FusedAdam exists.
|
||||
# Trainer.get_optimizer_cls_and_kwargs does not use FusedAdam, but only has to return a
|
||||
# class called, so mocking apex.optimizers.FusedAdam should be fine for testing and allow
|
||||
# Trainer.get_optimizer_cls_and_kwargs does not use FusedAdam. It only has to return the
|
||||
# class given, so mocking apex.optimizers.FusedAdam should be fine for testing and allow
|
||||
# the test to run without requiring an apex installation.
|
||||
mock = Mock()
|
||||
modules = {
|
||||
@@ -1930,6 +1942,33 @@ class TrainerOptimizerChoiceTest(unittest.TestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
Trainer.get_optimizer_cls_and_kwargs(args)
|
||||
|
||||
def test_bnb_adam8bit(self):
    """Check that the ADAMW_BNB option resolves to ``bitsandbytes.optim.Adam8bit``.

    bitsandbytes is replaced by a ``Mock`` in ``sys.modules`` so the test can
    run without the real package installed. Only the returned class is
    inspected, so a mocked ``Adam8bit`` attribute is sufficient here.
    """
    fake_bnb = Mock()
    fake_optim = fake_bnb.optim
    fake_adam8bit = fake_optim.Adam8bit

    # Every dotted name an import statement may look up must resolve to the mock.
    patched_modules = {
        "bitsandbytes": fake_bnb,
        "bitsandbytes.optim": fake_optim,
        "bitsandbytes.optim.Adam8bit": fake_adam8bit,
    }

    with patch.dict("sys.modules", patched_modules):
        self.check_optim_and_kwargs(OptimizerNames.ADAMW_BNB, default_adam_kwargs, fake_adam8bit)
|
||||
|
||||
def test_bnb_adam8bit_no_bnb(self):
    """Check that requesting ADAMW_BNB raises ``ValueError`` when bitsandbytes is unavailable."""
    args = TrainingArguments(optim=OptimizerNames.ADAMW_BNB, output_dir="None")

    # Pretend that bitsandbytes does not exist, even if installed. Mapping a module
    # name to None in sys.modules makes any import of that name fail.
    # The key must be the real module name: "bnb" is only a local alias
    # (`import bitsandbytes as bnb`), so patching "bnb.optim" would block nothing
    # and the test would fail on machines where bitsandbytes is installed.
    # NOTE(review): assumes Trainer imports from `bitsandbytes.optim` -- confirm in trainer.py.
    with patch.dict("sys.modules", {"bitsandbytes.optim": None}):
        with self.assertRaises(ValueError):
            Trainer.get_optimizer_cls_and_kwargs(args)
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_wandb
|
||||
|
||||
Reference in New Issue
Block a user