Paged Optimizer + Lion Optimizer for Trainer (#23217)
* Added lion and paged optimizers and made original tests pass.
* Added tests for paged and lion optimizers.
* Added and fixed optimizer tests.
* Style and quality checks.
---------
Co-authored-by: younesbelkada <younesbelkada@gmail.com>
This commit is contained in:
@@ -1170,6 +1170,38 @@ class Trainer:
|
|||||||
optimizer_kwargs.update(adam_kwargs)
|
optimizer_kwargs.update(adam_kwargs)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ValueError("Trainer tried to instantiate apex FusedAdam but apex is not installed!")
|
raise ValueError("Trainer tried to instantiate apex FusedAdam but apex is not installed!")
|
||||||
|
elif args.optim in [
|
||||||
|
OptimizerNames.ADAMW_BNB,
|
||||||
|
OptimizerNames.ADAMW_8BIT,
|
||||||
|
OptimizerNames.PAGED_ADAMW,
|
||||||
|
OptimizerNames.PAGED_ADAMW_8BIT,
|
||||||
|
OptimizerNames.LION,
|
||||||
|
OptimizerNames.LION_8BIT,
|
||||||
|
OptimizerNames.PAGED_LION,
|
||||||
|
OptimizerNames.PAGED_LION_8BIT,
|
||||||
|
]:
|
||||||
|
try:
|
||||||
|
from bitsandbytes.optim import AdamW, Lion
|
||||||
|
|
||||||
|
is_paged = False
|
||||||
|
optim_bits = 32
|
||||||
|
optimizer_cls = None
|
||||||
|
additional_optim_kwargs = adam_kwargs
|
||||||
|
if "paged" in args.optim:
|
||||||
|
is_paged = True
|
||||||
|
if "8bit" in args.optim:
|
||||||
|
optim_bits = 8
|
||||||
|
if "adam" in args.optim:
|
||||||
|
optimizer_cls = AdamW
|
||||||
|
elif "lion" in args.optim:
|
||||||
|
optimizer_cls = Lion
|
||||||
|
additional_optim_kwargs = {"betas": (args.adam_beta1, args.adam_beta2)}
|
||||||
|
|
||||||
|
bnb_kwargs = {"is_paged": is_paged, "optim_bits": optim_bits}
|
||||||
|
optimizer_kwargs.update(additional_optim_kwargs)
|
||||||
|
optimizer_kwargs.update(bnb_kwargs)
|
||||||
|
except ImportError:
|
||||||
|
raise ValueError("Trainer tried to instantiate bnb optimizer but bnb is not installed!")
|
||||||
elif args.optim == OptimizerNames.ADAMW_BNB:
|
elif args.optim == OptimizerNames.ADAMW_BNB:
|
||||||
try:
|
try:
|
||||||
from bitsandbytes.optim import Adam8bit
|
from bitsandbytes.optim import Adam8bit
|
||||||
|
|||||||
@@ -139,10 +139,17 @@ class OptimizerNames(ExplicitEnum):
|
|||||||
ADAMW_TORCH_XLA = "adamw_torch_xla"
|
ADAMW_TORCH_XLA = "adamw_torch_xla"
|
||||||
ADAMW_APEX_FUSED = "adamw_apex_fused"
|
ADAMW_APEX_FUSED = "adamw_apex_fused"
|
||||||
ADAFACTOR = "adafactor"
|
ADAFACTOR = "adafactor"
|
||||||
ADAMW_BNB = "adamw_bnb_8bit"
|
|
||||||
ADAMW_ANYPRECISION = "adamw_anyprecision"
|
ADAMW_ANYPRECISION = "adamw_anyprecision"
|
||||||
SGD = "sgd"
|
SGD = "sgd"
|
||||||
ADAGRAD = "adagrad"
|
ADAGRAD = "adagrad"
|
||||||
|
ADAMW_BNB = "adamw_bnb_8bit"
|
||||||
|
ADAMW_8BIT = "adamw_8bit" # just an alias for adamw_bnb_8bit
|
||||||
|
LION_8BIT = "lion_8bit"
|
||||||
|
LION = "lion_32bit"
|
||||||
|
PAGED_ADAMW = "paged_adamw_32bit"
|
||||||
|
PAGED_ADAMW_8BIT = "paged_adamw_8bit"
|
||||||
|
PAGED_LION = "paged_lion_32bit"
|
||||||
|
PAGED_LION_8BIT = "paged_lion_8bit"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -2474,6 +2474,11 @@ if is_torch_available():
|
|||||||
"lr": TrainingArguments.learning_rate,
|
"lr": TrainingArguments.learning_rate,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
default_lion_kwargs = {
|
||||||
|
"betas": (TrainingArguments.adam_beta1, TrainingArguments.adam_beta2),
|
||||||
|
"lr": TrainingArguments.learning_rate,
|
||||||
|
}
|
||||||
|
|
||||||
default_anyprecision_kwargs = {
|
default_anyprecision_kwargs = {
|
||||||
"use_kahan_summation": False,
|
"use_kahan_summation": False,
|
||||||
"momentum_dtype": torch.float32,
|
"momentum_dtype": torch.float32,
|
||||||
@@ -2525,11 +2530,59 @@ if is_torch_available():
|
|||||||
optim_test_params.append(
|
optim_test_params.append(
|
||||||
(
|
(
|
||||||
TrainingArguments(optim=OptimizerNames.ADAMW_BNB, output_dir="None"),
|
TrainingArguments(optim=OptimizerNames.ADAMW_BNB, output_dir="None"),
|
||||||
bnb.optim.Adam8bit,
|
bnb.optim.AdamW,
|
||||||
default_adam_kwargs,
|
default_adam_kwargs,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
optim_test_params.append(
|
||||||
|
(
|
||||||
|
TrainingArguments(optim=OptimizerNames.ADAMW_8BIT, output_dir="None"),
|
||||||
|
bnb.optim.AdamW,
|
||||||
|
default_adam_kwargs,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
optim_test_params.append(
|
||||||
|
(
|
||||||
|
TrainingArguments(optim=OptimizerNames.PAGED_ADAMW, output_dir="None"),
|
||||||
|
bnb.optim.AdamW,
|
||||||
|
default_adam_kwargs,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
optim_test_params.append(
|
||||||
|
(
|
||||||
|
TrainingArguments(optim=OptimizerNames.PAGED_ADAMW_8BIT, output_dir="None"),
|
||||||
|
bnb.optim.AdamW,
|
||||||
|
default_adam_kwargs,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
optim_test_params.append(
|
||||||
|
(
|
||||||
|
TrainingArguments(optim=OptimizerNames.LION, output_dir="None"),
|
||||||
|
bnb.optim.Lion,
|
||||||
|
default_lion_kwargs,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
optim_test_params.append(
|
||||||
|
(
|
||||||
|
TrainingArguments(optim=OptimizerNames.LION_8BIT, output_dir="None"),
|
||||||
|
bnb.optim.Lion,
|
||||||
|
default_lion_kwargs,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
optim_test_params.append(
|
||||||
|
(
|
||||||
|
TrainingArguments(optim=OptimizerNames.PAGED_LION_8BIT, output_dir="None"),
|
||||||
|
bnb.optim.Lion,
|
||||||
|
default_lion_kwargs,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if is_torchdistx_available():
|
if is_torchdistx_available():
|
||||||
import torchdistx
|
import torchdistx
|
||||||
|
|
||||||
@@ -2598,15 +2651,113 @@ class TrainerOptimizerChoiceTest(unittest.TestCase):
|
|||||||
modules = {
|
modules = {
|
||||||
"bitsandbytes": mock,
|
"bitsandbytes": mock,
|
||||||
"bitsandbytes.optim": mock.optim,
|
"bitsandbytes.optim": mock.optim,
|
||||||
"bitsandbytes.optim.Adam8bit": mock.optim.Adam8bit,
|
"bitsandbytes.optim.AdamW": mock.optim.AdamW,
|
||||||
}
|
}
|
||||||
with patch.dict("sys.modules", modules):
|
with patch.dict("sys.modules", modules):
|
||||||
self.check_optim_and_kwargs(
|
self.check_optim_and_kwargs(
|
||||||
TrainingArguments(optim=OptimizerNames.ADAMW_BNB, output_dir="None"),
|
TrainingArguments(optim=OptimizerNames.ADAMW_BNB, output_dir="None"),
|
||||||
mock.optim.Adam8bit,
|
mock.optim.AdamW,
|
||||||
default_adam_kwargs,
|
default_adam_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_bnb_paged_adam8bit_alias(self):
|
||||||
|
mock = Mock()
|
||||||
|
modules = {
|
||||||
|
"bitsandbytes": mock,
|
||||||
|
"bitsandbytes.optim": mock.optim,
|
||||||
|
"bitsandbytes.optim.AdamW": mock.optim.AdamW,
|
||||||
|
}
|
||||||
|
with patch.dict("sys.modules", modules):
|
||||||
|
self.check_optim_and_kwargs(
|
||||||
|
TrainingArguments(optim=OptimizerNames.ADAMW_8BIT, output_dir="None"),
|
||||||
|
mock.optim.AdamW,
|
||||||
|
default_adam_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_bnb_paged_adam(self):
|
||||||
|
mock = Mock()
|
||||||
|
modules = {
|
||||||
|
"bitsandbytes": mock,
|
||||||
|
"bitsandbytes.optim": mock.optim,
|
||||||
|
"bitsandbytes.optim.AdamW": mock.optim.AdamW,
|
||||||
|
}
|
||||||
|
with patch.dict("sys.modules", modules):
|
||||||
|
self.check_optim_and_kwargs(
|
||||||
|
TrainingArguments(optim=OptimizerNames.PAGED_ADAMW, output_dir="None"),
|
||||||
|
mock.optim.AdamW,
|
||||||
|
default_adam_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_bnb_paged_adam8bit(self):
|
||||||
|
mock = Mock()
|
||||||
|
modules = {
|
||||||
|
"bitsandbytes": mock,
|
||||||
|
"bitsandbytes.optim": mock.optim,
|
||||||
|
"bitsandbytes.optim.AdamW": mock.optim.AdamW,
|
||||||
|
}
|
||||||
|
with patch.dict("sys.modules", modules):
|
||||||
|
self.check_optim_and_kwargs(
|
||||||
|
TrainingArguments(optim=OptimizerNames.PAGED_ADAMW_8BIT, output_dir="None"),
|
||||||
|
mock.optim.AdamW,
|
||||||
|
default_adam_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_bnb_lion(self):
|
||||||
|
mock = Mock()
|
||||||
|
modules = {
|
||||||
|
"bitsandbytes": mock,
|
||||||
|
"bitsandbytes.optim": mock.optim,
|
||||||
|
"bitsandbytes.optim.Lion": mock.optim.Lion,
|
||||||
|
}
|
||||||
|
with patch.dict("sys.modules", modules):
|
||||||
|
self.check_optim_and_kwargs(
|
||||||
|
TrainingArguments(optim=OptimizerNames.LION, output_dir="None"),
|
||||||
|
mock.optim.Lion,
|
||||||
|
default_lion_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_bnb_lion8bit(self):
|
||||||
|
mock = Mock()
|
||||||
|
modules = {
|
||||||
|
"bitsandbytes": mock,
|
||||||
|
"bitsandbytes.optim": mock.optim,
|
||||||
|
"bitsandbytes.optim.Lion": mock.optim.Lion,
|
||||||
|
}
|
||||||
|
with patch.dict("sys.modules", modules):
|
||||||
|
self.check_optim_and_kwargs(
|
||||||
|
TrainingArguments(optim=OptimizerNames.LION_8BIT, output_dir="None"),
|
||||||
|
mock.optim.Lion,
|
||||||
|
default_lion_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_bnb_paged_lion8bit(self):
|
||||||
|
mock = Mock()
|
||||||
|
modules = {
|
||||||
|
"bitsandbytes": mock,
|
||||||
|
"bitsandbytes.optim": mock.optim,
|
||||||
|
"bitsandbytes.optim.Lion": mock.optim.Lion,
|
||||||
|
}
|
||||||
|
with patch.dict("sys.modules", modules):
|
||||||
|
self.check_optim_and_kwargs(
|
||||||
|
TrainingArguments(optim=OptimizerNames.PAGED_LION_8BIT, output_dir="None"),
|
||||||
|
mock.optim.Lion,
|
||||||
|
default_lion_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_bnb_paged_lion(self):
|
||||||
|
mock = Mock()
|
||||||
|
modules = {
|
||||||
|
"bitsandbytes": mock,
|
||||||
|
"bitsandbytes.optim": mock.optim,
|
||||||
|
"bitsandbytes.optim.Lion": mock.optim.Lion,
|
||||||
|
}
|
||||||
|
with patch.dict("sys.modules", modules):
|
||||||
|
self.check_optim_and_kwargs(
|
||||||
|
TrainingArguments(optim=OptimizerNames.PAGED_LION, output_dir="None"),
|
||||||
|
mock.optim.Lion,
|
||||||
|
default_lion_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
def test_bnb_adam8bit_no_bnb(self):
|
def test_bnb_adam8bit_no_bnb(self):
|
||||||
args = TrainingArguments(optim=OptimizerNames.ADAMW_BNB, output_dir="None")
|
args = TrainingArguments(optim=OptimizerNames.ADAMW_BNB, output_dir="None")
|
||||||
|
|
||||||
@@ -2616,6 +2767,42 @@ class TrainerOptimizerChoiceTest(unittest.TestCase):
|
|||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
Trainer.get_optimizer_cls_and_kwargs(args)
|
Trainer.get_optimizer_cls_and_kwargs(args)
|
||||||
|
|
||||||
|
def test_bnb_paged_adam_no_bnb(self):
|
||||||
|
args = TrainingArguments(optim=OptimizerNames.PAGED_ADAMW, output_dir="None")
|
||||||
|
|
||||||
|
# Pretend that bnb does not exist, even if installed. By setting bnb to None, importing
|
||||||
|
# bnb will fail even if bnb is installed.
|
||||||
|
with patch.dict("sys.modules", {"bitsandbytes.optim": None}):
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
Trainer.get_optimizer_cls_and_kwargs(args)
|
||||||
|
|
||||||
|
def test_bnb_paged_adam8bit_no_bnb(self):
|
||||||
|
args = TrainingArguments(optim=OptimizerNames.PAGED_ADAMW_8BIT, output_dir="None")
|
||||||
|
|
||||||
|
# Pretend that bnb does not exist, even if installed. By setting bnb to None, importing
|
||||||
|
# bnb will fail even if bnb is installed.
|
||||||
|
with patch.dict("sys.modules", {"bitsandbytes.optim": None}):
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
Trainer.get_optimizer_cls_and_kwargs(args)
|
||||||
|
|
||||||
|
def test_bnb_paged_lion_no_bnb(self):
|
||||||
|
args = TrainingArguments(optim=OptimizerNames.PAGED_LION, output_dir="None")
|
||||||
|
|
||||||
|
# Pretend that bnb does not exist, even if installed. By setting bnb to None, importing
|
||||||
|
# bnb will fail even if bnb is installed.
|
||||||
|
with patch.dict("sys.modules", {"bitsandbytes.optim": None}):
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
Trainer.get_optimizer_cls_and_kwargs(args)
|
||||||
|
|
||||||
|
def test_bnb_paged_lion8bit_no_bnb(self):
|
||||||
|
args = TrainingArguments(optim=OptimizerNames.PAGED_LION_8BIT, output_dir="None")
|
||||||
|
|
||||||
|
# Pretend that bnb does not exist, even if installed. By setting bnb to None, importing
|
||||||
|
# bnb will fail even if bnb is installed.
|
||||||
|
with patch.dict("sys.modules", {"bitsandbytes.optim": None}):
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
Trainer.get_optimizer_cls_and_kwargs(args)
|
||||||
|
|
||||||
def test_anyprecision_adamw(self):
|
def test_anyprecision_adamw(self):
|
||||||
# Pretend that torchdistx is installed and mock torchdistx.optimizers.AnyPrecisionAdamW exists.
|
# Pretend that torchdistx is installed and mock torchdistx.optimizers.AnyPrecisionAdamW exists.
|
||||||
# Trainer.get_optimizer_cls_and_kwargs does not use AnyPrecisionAdamW. It only has to return the
|
# Trainer.get_optimizer_cls_and_kwargs does not use AnyPrecisionAdamW. It only has to return the
|
||||||
|
|||||||
Reference in New Issue
Block a user