Paged Optimizer + Lion Optimizer for Trainer (#23217)
* Added lion and paged optimizers and made original tests pass.
* Added tests for paged and lion optimizers.
* Added and fixed optimizer tests.
* Style and quality checks.

---------

Co-authored-by: younesbelkada <younesbelkada@gmail.com>
This commit is contained in:
@@ -2474,6 +2474,11 @@ if is_torch_available():
|
||||
"lr": TrainingArguments.learning_rate,
|
||||
}
|
||||
|
||||
default_lion_kwargs = {
|
||||
"betas": (TrainingArguments.adam_beta1, TrainingArguments.adam_beta2),
|
||||
"lr": TrainingArguments.learning_rate,
|
||||
}
|
||||
|
||||
default_anyprecision_kwargs = {
|
||||
"use_kahan_summation": False,
|
||||
"momentum_dtype": torch.float32,
|
||||
@@ -2525,11 +2530,59 @@ if is_torch_available():
|
||||
optim_test_params.append(
|
||||
(
|
||||
TrainingArguments(optim=OptimizerNames.ADAMW_BNB, output_dir="None"),
|
||||
bnb.optim.Adam8bit,
|
||||
bnb.optim.AdamW,
|
||||
default_adam_kwargs,
|
||||
)
|
||||
)
|
||||
|
||||
optim_test_params.append(
|
||||
(
|
||||
TrainingArguments(optim=OptimizerNames.ADAMW_8BIT, output_dir="None"),
|
||||
bnb.optim.AdamW,
|
||||
default_adam_kwargs,
|
||||
)
|
||||
)
|
||||
|
||||
optim_test_params.append(
|
||||
(
|
||||
TrainingArguments(optim=OptimizerNames.PAGED_ADAMW, output_dir="None"),
|
||||
bnb.optim.AdamW,
|
||||
default_adam_kwargs,
|
||||
)
|
||||
)
|
||||
|
||||
optim_test_params.append(
|
||||
(
|
||||
TrainingArguments(optim=OptimizerNames.PAGED_ADAMW_8BIT, output_dir="None"),
|
||||
bnb.optim.AdamW,
|
||||
default_adam_kwargs,
|
||||
)
|
||||
)
|
||||
|
||||
optim_test_params.append(
|
||||
(
|
||||
TrainingArguments(optim=OptimizerNames.LION, output_dir="None"),
|
||||
bnb.optim.Lion,
|
||||
default_lion_kwargs,
|
||||
)
|
||||
)
|
||||
|
||||
optim_test_params.append(
|
||||
(
|
||||
TrainingArguments(optim=OptimizerNames.LION_8BIT, output_dir="None"),
|
||||
bnb.optim.Lion,
|
||||
default_lion_kwargs,
|
||||
)
|
||||
)
|
||||
|
||||
optim_test_params.append(
|
||||
(
|
||||
TrainingArguments(optim=OptimizerNames.PAGED_LION_8BIT, output_dir="None"),
|
||||
bnb.optim.Lion,
|
||||
default_lion_kwargs,
|
||||
)
|
||||
)
|
||||
|
||||
if is_torchdistx_available():
|
||||
import torchdistx
|
||||
|
||||
@@ -2598,15 +2651,113 @@ class TrainerOptimizerChoiceTest(unittest.TestCase):
|
||||
modules = {
|
||||
"bitsandbytes": mock,
|
||||
"bitsandbytes.optim": mock.optim,
|
||||
"bitsandbytes.optim.Adam8bit": mock.optim.Adam8bit,
|
||||
"bitsandbytes.optim.AdamW": mock.optim.AdamW,
|
||||
}
|
||||
with patch.dict("sys.modules", modules):
|
||||
self.check_optim_and_kwargs(
|
||||
TrainingArguments(optim=OptimizerNames.ADAMW_BNB, output_dir="None"),
|
||||
mock.optim.Adam8bit,
|
||||
mock.optim.AdamW,
|
||||
default_adam_kwargs,
|
||||
)
|
||||
|
||||
def test_bnb_paged_adam8bit_alias(self):
|
||||
mock = Mock()
|
||||
modules = {
|
||||
"bitsandbytes": mock,
|
||||
"bitsandbytes.optim": mock.optim,
|
||||
"bitsandbytes.optim.AdamW": mock.optim.AdamW,
|
||||
}
|
||||
with patch.dict("sys.modules", modules):
|
||||
self.check_optim_and_kwargs(
|
||||
TrainingArguments(optim=OptimizerNames.ADAMW_8BIT, output_dir="None"),
|
||||
mock.optim.AdamW,
|
||||
default_adam_kwargs,
|
||||
)
|
||||
|
||||
def test_bnb_paged_adam(self):
|
||||
mock = Mock()
|
||||
modules = {
|
||||
"bitsandbytes": mock,
|
||||
"bitsandbytes.optim": mock.optim,
|
||||
"bitsandbytes.optim.AdamW": mock.optim.AdamW,
|
||||
}
|
||||
with patch.dict("sys.modules", modules):
|
||||
self.check_optim_and_kwargs(
|
||||
TrainingArguments(optim=OptimizerNames.PAGED_ADAMW, output_dir="None"),
|
||||
mock.optim.AdamW,
|
||||
default_adam_kwargs,
|
||||
)
|
||||
|
||||
def test_bnb_paged_adam8bit(self):
|
||||
mock = Mock()
|
||||
modules = {
|
||||
"bitsandbytes": mock,
|
||||
"bitsandbytes.optim": mock.optim,
|
||||
"bitsandbytes.optim.AdamW": mock.optim.AdamW,
|
||||
}
|
||||
with patch.dict("sys.modules", modules):
|
||||
self.check_optim_and_kwargs(
|
||||
TrainingArguments(optim=OptimizerNames.PAGED_ADAMW_8BIT, output_dir="None"),
|
||||
mock.optim.AdamW,
|
||||
default_adam_kwargs,
|
||||
)
|
||||
|
||||
def test_bnb_lion(self):
|
||||
mock = Mock()
|
||||
modules = {
|
||||
"bitsandbytes": mock,
|
||||
"bitsandbytes.optim": mock.optim,
|
||||
"bitsandbytes.optim.Lion": mock.optim.Lion,
|
||||
}
|
||||
with patch.dict("sys.modules", modules):
|
||||
self.check_optim_and_kwargs(
|
||||
TrainingArguments(optim=OptimizerNames.LION, output_dir="None"),
|
||||
mock.optim.Lion,
|
||||
default_lion_kwargs,
|
||||
)
|
||||
|
||||
def test_bnb_lion8bit(self):
|
||||
mock = Mock()
|
||||
modules = {
|
||||
"bitsandbytes": mock,
|
||||
"bitsandbytes.optim": mock.optim,
|
||||
"bitsandbytes.optim.Lion": mock.optim.Lion,
|
||||
}
|
||||
with patch.dict("sys.modules", modules):
|
||||
self.check_optim_and_kwargs(
|
||||
TrainingArguments(optim=OptimizerNames.LION_8BIT, output_dir="None"),
|
||||
mock.optim.Lion,
|
||||
default_lion_kwargs,
|
||||
)
|
||||
|
||||
def test_bnb_paged_lion8bit(self):
|
||||
mock = Mock()
|
||||
modules = {
|
||||
"bitsandbytes": mock,
|
||||
"bitsandbytes.optim": mock.optim,
|
||||
"bitsandbytes.optim.Lion": mock.optim.Lion,
|
||||
}
|
||||
with patch.dict("sys.modules", modules):
|
||||
self.check_optim_and_kwargs(
|
||||
TrainingArguments(optim=OptimizerNames.PAGED_LION_8BIT, output_dir="None"),
|
||||
mock.optim.Lion,
|
||||
default_lion_kwargs,
|
||||
)
|
||||
|
||||
def test_bnb_paged_lion(self):
|
||||
mock = Mock()
|
||||
modules = {
|
||||
"bitsandbytes": mock,
|
||||
"bitsandbytes.optim": mock.optim,
|
||||
"bitsandbytes.optim.Lion": mock.optim.Lion,
|
||||
}
|
||||
with patch.dict("sys.modules", modules):
|
||||
self.check_optim_and_kwargs(
|
||||
TrainingArguments(optim=OptimizerNames.PAGED_LION, output_dir="None"),
|
||||
mock.optim.Lion,
|
||||
default_lion_kwargs,
|
||||
)
|
||||
|
||||
def test_bnb_adam8bit_no_bnb(self):
|
||||
args = TrainingArguments(optim=OptimizerNames.ADAMW_BNB, output_dir="None")
|
||||
|
||||
@@ -2616,6 +2767,42 @@ class TrainerOptimizerChoiceTest(unittest.TestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
Trainer.get_optimizer_cls_and_kwargs(args)
|
||||
|
||||
def test_bnb_paged_adam_no_bnb(self):
|
||||
args = TrainingArguments(optim=OptimizerNames.PAGED_ADAMW, output_dir="None")
|
||||
|
||||
# Pretend that bnb does not exist, even if installed. By setting bnb to None, importing
|
||||
# bnb will fail even if bnb is installed.
|
||||
with patch.dict("sys.modules", {"bitsandbytes.optim": None}):
|
||||
with self.assertRaises(ValueError):
|
||||
Trainer.get_optimizer_cls_and_kwargs(args)
|
||||
|
||||
def test_bnb_paged_adam8bit_no_bnb(self):
|
||||
args = TrainingArguments(optim=OptimizerNames.PAGED_ADAMW_8BIT, output_dir="None")
|
||||
|
||||
# Pretend that bnb does not exist, even if installed. By setting bnb to None, importing
|
||||
# bnb will fail even if bnb is installed.
|
||||
with patch.dict("sys.modules", {"bitsandbytes.optim": None}):
|
||||
with self.assertRaises(ValueError):
|
||||
Trainer.get_optimizer_cls_and_kwargs(args)
|
||||
|
||||
def test_bnb_paged_lion_no_bnb(self):
|
||||
args = TrainingArguments(optim=OptimizerNames.PAGED_LION, output_dir="None")
|
||||
|
||||
# Pretend that bnb does not exist, even if installed. By setting bnb to None, importing
|
||||
# bnb will fail even if bnb is installed.
|
||||
with patch.dict("sys.modules", {"bitsandbytes.optim": None}):
|
||||
with self.assertRaises(ValueError):
|
||||
Trainer.get_optimizer_cls_and_kwargs(args)
|
||||
|
||||
def test_bnb_paged_lion8bit_no_bnb(self):
|
||||
args = TrainingArguments(optim=OptimizerNames.PAGED_LION_8BIT, output_dir="None")
|
||||
|
||||
# Pretend that bnb does not exist, even if installed. By setting bnb to None, importing
|
||||
# bnb will fail even if bnb is installed.
|
||||
with patch.dict("sys.modules", {"bitsandbytes.optim": None}):
|
||||
with self.assertRaises(ValueError):
|
||||
Trainer.get_optimizer_cls_and_kwargs(args)
|
||||
|
||||
def test_anyprecision_adamw(self):
|
||||
# Pretend that torchdistx is installed and mock torchdistx.optimizers.AnyPrecisionAdamW exists.
|
||||
# Trainer.get_optimizer_cls_and_kwargs does not use AnyPrecisionAdamW. It only has to return the
|
||||
|
||||
Reference in New Issue
Block a user