[fix] make legacy bnb code work (#37331)
* [fix] make legacy bnb code work
* [fix] use get with default instead of getter
* add test for bnb 8bit optim skip embed
* [fix] style
* add require annotation of bnb

---------

Co-authored-by: jaycha <jaycha@ncsoft.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
@@ -1247,7 +1247,7 @@ class Trainer:
|
|||||||
|
|
||||||
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
|
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
|
||||||
|
|
||||||
if optimizer_cls.__name__ == "Adam8bit":
|
if "bitsandbytes" in str(optimizer_cls) and optimizer_kwargs.get("optim_bits", None) == 8:
|
||||||
import bitsandbytes
|
import bitsandbytes
|
||||||
|
|
||||||
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
|
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
|
||||||
|
|||||||
@@ -5962,3 +5962,22 @@ class OptimizerAndModelInspectionTest(unittest.TestCase):
|
|||||||
param = next(model.parameters())
|
param = next(model.parameters())
|
||||||
group = trainer.get_optimizer_group(param)
|
group = trainer.get_optimizer_group(param)
|
||||||
self.assertIn(param, group["params"])
|
self.assertIn(param, group["params"])
|
||||||
|
|
||||||
|
@require_bitsandbytes
def test_bnb_8bit_optimizer_skip_embedding(self):
    """Check that 8-bit bitsandbytes optimizers keep embeddings in 32-bit.

    For each 8-bit optimizer name, an optimizer is created through
    ``Trainer.create_optimizer`` and the override list exposed on
    ``optimizer.mng.module_weight_config_triple`` is inspected. The first
    entry must target the ``weight`` of a ``torch.nn.Embedding`` with the
    config ``{"optim_bits": 32}``.
    """
    model = BasicTextGenerationModel(8, 4)
    with tempfile.TemporaryDirectory() as tmp_dir:
        for name_optim in ["rmsprop_bnb_8bit", "adamw_8bit"]:
            # subTest reports which optimizer failed and lets the remaining
            # configuration still run instead of stopping at the first error.
            with self.subTest(optim=name_optim):
                args = TrainingArguments(
                    output_dir=tmp_dir,
                    report_to="none",
                    optim=name_optim,
                )
                trainer = Trainer(model=model, args=args)
                optimizer = trainer.create_optimizer()

                # NOTE(review): `mng` is presumably the bitsandbytes global
                # optimizer manager, which would make this list shared across
                # optimizers created in the same process -- confirm.
                modules = optimizer.mng.module_weight_config_triple
                self.assertNotEqual(len(modules), 0)

                module, name, config = modules[0]
                self.assertIsInstance(module, torch.nn.Embedding)
                self.assertEqual(name, "weight")
                self.assertDictEqual(config, {"optim_bits": 32})
|
||||||
|
|||||||
Reference in New Issue
Block a user