[fix] make legacy bnb code work (#37331)
* [fix] make legacy bnb code work * [fix] use get with default instead of getter * add test for bnb 8bit optim skip embed * [fix] style * add require annotation of bnb --------- Co-authored-by: jaycha <jaycha@ncsoft.com> Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
@@ -5962,3 +5962,22 @@ class OptimizerAndModelInspectionTest(unittest.TestCase):
|
||||
param = next(model.parameters())
|
||||
group = trainer.get_optimizer_group(param)
|
||||
self.assertIn(param, group["params"])
|
||||
|
||||
@require_bitsandbytes
|
||||
def test_bnb_8bit_optimizer_skip_embedding(self):
|
||||
model = BasicTextGenerationModel(8, 4)
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
for name_optim in ["rmsprop_bnb_8bit", "adamw_8bit"]:
|
||||
args = TrainingArguments(
|
||||
output_dir=tmp_dir,
|
||||
report_to="none",
|
||||
optim=name_optim,
|
||||
)
|
||||
trainer = Trainer(model=model, args=args)
|
||||
optimizer = trainer.create_optimizer()
|
||||
modules = optimizer.mng.module_weight_config_triple
|
||||
self.assertNotEqual(len(modules), 0)
|
||||
module, name, config = modules[0]
|
||||
self.assertIsInstance(module, torch.nn.Embedding)
|
||||
self.assertEqual(name, "weight")
|
||||
self.assertDictEqual(config, {"optim_bits": 32})
|
||||
|
||||
Reference in New Issue
Block a user