From 31ea547b7a4e9700dbe299f84b55f7e12c58ec13 Mon Sep 17 00:00:00 2001 From: youngrok cha Date: Tue, 22 Apr 2025 18:17:29 +0900 Subject: [PATCH] [fix] make legacy bnb code work (#37331) * [fix] make legacy bnb code work * [fix] use get with default instead of getter * add test for bnb 8bit optim skip embed * [fix] style * add require annotation of bnb --------- Co-authored-by: jaycha Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> --- src/transformers/trainer.py | 2 +- tests/trainer/test_trainer.py | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 6a7b8f1034..d644534020 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1247,7 +1247,7 @@ class Trainer: self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) - if optimizer_cls.__name__ == "Adam8bit": + if "bitsandbytes" in str(optimizer_cls) and optimizer_kwargs.get("optim_bits", None) == 8: import bitsandbytes manager = bitsandbytes.optim.GlobalOptimManager.get_instance() diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index e9f11bf729..0eb2ba6989 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -5962,3 +5962,22 @@ class OptimizerAndModelInspectionTest(unittest.TestCase): param = next(model.parameters()) group = trainer.get_optimizer_group(param) self.assertIn(param, group["params"]) + + @require_bitsandbytes + def test_bnb_8bit_optimizer_skip_embedding(self): + model = BasicTextGenerationModel(8, 4) + with tempfile.TemporaryDirectory() as tmp_dir: + for name_optim in ["rmsprop_bnb_8bit", "adamw_8bit"]: + args = TrainingArguments( + output_dir=tmp_dir, + report_to="none", + optim=name_optim, + ) + trainer = Trainer(model=model, args=args) + optimizer = trainer.create_optimizer() + modules = optimizer.mng.module_weight_config_triple + self.assertNotEqual(len(modules), 0) + module, name, config = modules[0] + self.assertIsInstance(module, torch.nn.Embedding) + self.assertEqual(name, "weight") + self.assertDictEqual(config, {"optim_bits": 32})