Add AdEMAMix optimizer (#33682)
* Add AdEMAMix optimizer
* Fix test
* Update tests/trainer/test_trainer.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
@@ -1237,6 +1237,10 @@ class Trainer:
|
|||||||
OptimizerNames.ADAMW_8BIT,
|
OptimizerNames.ADAMW_8BIT,
|
||||||
OptimizerNames.PAGED_ADAMW,
|
OptimizerNames.PAGED_ADAMW,
|
||||||
OptimizerNames.PAGED_ADAMW_8BIT,
|
OptimizerNames.PAGED_ADAMW_8BIT,
|
||||||
|
OptimizerNames.ADEMAMIX,
|
||||||
|
OptimizerNames.ADEMAMIX_8BIT,
|
||||||
|
OptimizerNames.PAGED_ADEMAMIX,
|
||||||
|
OptimizerNames.PAGED_ADEMAMIX_8BIT,
|
||||||
OptimizerNames.LION,
|
OptimizerNames.LION,
|
||||||
OptimizerNames.LION_8BIT,
|
OptimizerNames.LION_8BIT,
|
||||||
OptimizerNames.PAGED_LION,
|
OptimizerNames.PAGED_LION,
|
||||||
@@ -1266,6 +1270,33 @@ class Trainer:
|
|||||||
# Above we pass all `adam_kwargs` to the optimizer, here
|
# Above we pass all `adam_kwargs` to the optimizer, here
|
||||||
# we only pass `optim_args` which can be passed by the user.
|
# we only pass `optim_args` which can be passed by the user.
|
||||||
additional_optim_kwargs = optim_args
|
additional_optim_kwargs = optim_args
|
||||||
|
elif "ademamix" in args.optim:
|
||||||
|
if is_bitsandbytes_available() and version.parse(
|
||||||
|
importlib.metadata.version("bitsandbytes")
|
||||||
|
) < version.parse("0.44.0"):
|
||||||
|
raise ValueError(
|
||||||
|
"The AdEMAMix optimizer is not supported by your current version of `bitsandbytes`. "
|
||||||
|
"Please install `bitsandbytes` >= 0.44.0."
|
||||||
|
)
|
||||||
|
|
||||||
|
from bitsandbytes.optim import AdEMAMix
|
||||||
|
|
||||||
|
optimizer_cls = AdEMAMix
|
||||||
|
additional_optim_kwargs = {
|
||||||
|
"betas": (
|
||||||
|
float(optim_args.get("beta1", args.adam_beta1)),
|
||||||
|
float(optim_args.get("beta2", args.adam_beta2)),
|
||||||
|
float(optim_args.get("beta3", 0.9999)),
|
||||||
|
),
|
||||||
|
"alpha": float(optim_args.get("alpha", 5.0)),
|
||||||
|
"eps": float(optim_args.get("eps", args.adam_epsilon)),
|
||||||
|
}
|
||||||
|
|
||||||
|
if "t_alpha" in optim_args:
|
||||||
|
additional_optim_kwargs["t_alpha"] = int(optim_args["t_alpha"])
|
||||||
|
|
||||||
|
if "t_beta3" in optim_args:
|
||||||
|
additional_optim_kwargs["t_beta3"] = int(optim_args["t_beta3"])
|
||||||
|
|
||||||
bnb_kwargs = {"optim_bits": optim_bits}
|
bnb_kwargs = {"optim_bits": optim_bits}
|
||||||
if "rmsprop" not in args.optim:
|
if "rmsprop" not in args.optim:
|
||||||
|
|||||||
@@ -155,14 +155,18 @@ class OptimizerNames(ExplicitEnum):
|
|||||||
ADAFACTOR = "adafactor"
|
ADAFACTOR = "adafactor"
|
||||||
ADAMW_ANYPRECISION = "adamw_anyprecision"
|
ADAMW_ANYPRECISION = "adamw_anyprecision"
|
||||||
ADAMW_TORCH_4BIT = "adamw_torch_4bit"
|
ADAMW_TORCH_4BIT = "adamw_torch_4bit"
|
||||||
|
ADEMAMIX = "ademamix"
|
||||||
SGD = "sgd"
|
SGD = "sgd"
|
||||||
ADAGRAD = "adagrad"
|
ADAGRAD = "adagrad"
|
||||||
ADAMW_BNB = "adamw_bnb_8bit"
|
ADAMW_BNB = "adamw_bnb_8bit"
|
||||||
ADAMW_8BIT = "adamw_8bit" # just an alias for adamw_bnb_8bit
|
ADAMW_8BIT = "adamw_8bit" # just an alias for adamw_bnb_8bit
|
||||||
|
ADEMAMIX_8BIT = "ademamix_8bit"
|
||||||
LION_8BIT = "lion_8bit"
|
LION_8BIT = "lion_8bit"
|
||||||
LION = "lion_32bit"
|
LION = "lion_32bit"
|
||||||
PAGED_ADAMW = "paged_adamw_32bit"
|
PAGED_ADAMW = "paged_adamw_32bit"
|
||||||
PAGED_ADAMW_8BIT = "paged_adamw_8bit"
|
PAGED_ADAMW_8BIT = "paged_adamw_8bit"
|
||||||
|
PAGED_ADEMAMIX = "paged_ademamix_32bit"
|
||||||
|
PAGED_ADEMAMIX_8BIT = "paged_ademamix_8bit"
|
||||||
PAGED_LION = "paged_lion_32bit"
|
PAGED_LION = "paged_lion_32bit"
|
||||||
PAGED_LION_8BIT = "paged_lion_8bit"
|
PAGED_LION_8BIT = "paged_lion_8bit"
|
||||||
RMSPROP = "rmsprop"
|
RMSPROP = "rmsprop"
|
||||||
@@ -618,7 +622,7 @@ class TrainingArguments:
|
|||||||
"adafactor". See `OptimizerNames` in [training_args.py](https://github.com/huggingface/transformers/blob/main/src/transformers/training_args.py)
|
"adafactor". See `OptimizerNames` in [training_args.py](https://github.com/huggingface/transformers/blob/main/src/transformers/training_args.py)
|
||||||
for a full list of optimizers.
|
for a full list of optimizers.
|
||||||
optim_args (`str`, *optional*):
|
optim_args (`str`, *optional*):
|
||||||
Optional arguments that are supplied to AnyPrecisionAdamW.
|
Optional arguments that are supplied to optimizers such as AnyPrecisionAdamW, AdEMAMix, and GaLore.
|
||||||
group_by_length (`bool`, *optional*, defaults to `False`):
|
group_by_length (`bool`, *optional*, defaults to `False`):
|
||||||
Whether or not to group together samples of roughly the same length in the training dataset (to minimize
|
Whether or not to group together samples of roughly the same length in the training dataset (to minimize
|
||||||
padding applied and be more efficient). Only useful if applying dynamic padding.
|
padding applied and be more efficient). Only useful if applying dynamic padding.
|
||||||
|
|||||||
@@ -15,6 +15,7 @@
|
|||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import gc
|
import gc
|
||||||
|
import importlib
|
||||||
import json
|
import json
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
@@ -32,6 +33,7 @@ from unittest.mock import Mock, patch
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from huggingface_hub import HfFolder, ModelCard, create_branch, delete_repo, list_repo_commits, list_repo_files
|
from huggingface_hub import HfFolder, ModelCard, create_branch, delete_repo, list_repo_commits, list_repo_files
|
||||||
|
from packaging import version
|
||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
from requests.exceptions import HTTPError
|
from requests.exceptions import HTTPError
|
||||||
|
|
||||||
@@ -1091,6 +1093,40 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
# Check that it trains without errors
|
# Check that it trains without errors
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
|
||||||
|
@require_bitsandbytes
def test_ademamix_bnb(self):
    # Smoke test: a tiny GPT-2 must get through `Trainer.train()` when the
    # optimizer is selected with `optim="ademamix"`.
    gpt2_config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
    tiny_gpt2 = GPT2LMHeadModel(gpt2_config)
    # The model is built before the token ids are sampled so the RNG is
    # consumed in the same order as before.
    token_ids = torch.randint(0, 100, (128,))
    train_dataset = RepeatDataset(token_ids)

    with tempfile.TemporaryDirectory() as tmpdir:
        # Trainer without inf/nan filter
        training_args = TrainingArguments(
            output_dir=tmpdir,
            learning_rate=1e-9,
            logging_steps=5,
            logging_nan_inf_filter=False,
            optim="ademamix",
        )
        trainer = Trainer(model=tiny_gpt2, args=training_args, train_dataset=train_dataset)

        # Check that it trains without errors
        trainer.train()
|
||||||
|
|
||||||
|
@require_bitsandbytes
def test_ademamix_bnb_8bit(self):
    # Smoke test: a tiny GPT-2 must get through `Trainer.train()` when the
    # optimizer is selected with `optim="ademamix_8bit"`.
    gpt2_config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
    tiny_gpt2 = GPT2LMHeadModel(gpt2_config)
    # The model is built before the token ids are sampled so the RNG is
    # consumed in the same order as before.
    token_ids = torch.randint(0, 100, (128,))
    train_dataset = RepeatDataset(token_ids)

    with tempfile.TemporaryDirectory() as tmpdir:
        # Trainer without inf/nan filter
        training_args = TrainingArguments(
            output_dir=tmpdir,
            learning_rate=1e-9,
            logging_steps=5,
            logging_nan_inf_filter=False,
            optim="ademamix_8bit",
        )
        trainer = Trainer(model=tiny_gpt2, args=training_args, train_dataset=train_dataset)

        # Check that it trains without errors
        trainer.train()
|
||||||
|
|
||||||
@require_bitsandbytes
|
@require_bitsandbytes
|
||||||
def test_rmsprop_bnb_8bit(self):
|
def test_rmsprop_bnb_8bit(self):
|
||||||
config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
|
config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
|
||||||
@@ -4187,6 +4223,13 @@ if is_torch_available():
|
|||||||
"lr": TrainingArguments.learning_rate,
|
"lr": TrainingArguments.learning_rate,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Keyword arguments the AdEMAMix optimizer tests expect when no `optim_args`
# are given: the Adam beta1/beta2/epsilon and learning-rate defaults of
# `TrainingArguments`, plus a third beta of 0.9999 and alpha of 5.0.
default_ademamix_kwargs = {
|
||||||
|
"betas": (TrainingArguments.adam_beta1, TrainingArguments.adam_beta2, 0.9999),
|
||||||
|
"alpha": 5.0,
|
||||||
|
"eps": TrainingArguments.adam_epsilon,
|
||||||
|
"lr": TrainingArguments.learning_rate,
|
||||||
|
}
|
||||||
|
|
||||||
default_anyprecision_kwargs = {
|
default_anyprecision_kwargs = {
|
||||||
"use_kahan_summation": False,
|
"use_kahan_summation": False,
|
||||||
"momentum_dtype": torch.float32,
|
"momentum_dtype": torch.float32,
|
||||||
@@ -4291,6 +4334,36 @@ if is_torch_available():
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# `import importlib` alone does not guarantee that the `importlib.metadata`
# submodule has been loaded; import it explicitly so the version lookup below
# does not depend on some other module having imported it first.
import importlib.metadata

# AdEMAMix needs `bitsandbytes` >= 0.44.0, so the parametrized cases are only
# registered when the installed release is recent enough.
if version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.44.0"):
    # Every AdEMAMix variant is expected to resolve to the same optimizer
    # class with the same default kwargs; only the `optim` name differs.
    # The order matches the original sequence of appends.
    for ademamix_optim in (
        OptimizerNames.ADEMAMIX,
        OptimizerNames.ADEMAMIX_8BIT,
        OptimizerNames.PAGED_ADEMAMIX_8BIT,
        OptimizerNames.PAGED_ADEMAMIX,
    ):
        optim_test_params.append(
            (
                TrainingArguments(optim=ademamix_optim, output_dir="None"),
                bnb.optim.AdEMAMix,
                default_ademamix_kwargs,
            )
        )
|
||||||
|
|
||||||
if is_torchdistx_available():
|
if is_torchdistx_available():
|
||||||
import torchdistx
|
import torchdistx
|
||||||
|
|
||||||
@@ -4420,6 +4493,62 @@ class TrainerOptimizerChoiceTest(unittest.TestCase):
|
|||||||
default_adam_kwargs,
|
default_adam_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_bnb_ademamix(self):
    # Stand a mock in for `bitsandbytes` so the optimizer selection for
    # `OptimizerNames.ADEMAMIX` can be checked against the mocked class.
    bnb_stub = Mock()
    expected_cls = bnb_stub.optim.AdEMAMix
    stubbed_modules = {
        "bitsandbytes": bnb_stub,
        "bitsandbytes.optim": bnb_stub.optim,
        "bitsandbytes.optim.AdEMAMix": expected_cls,
    }
    with patch.dict("sys.modules", stubbed_modules):
        training_args = TrainingArguments(optim=OptimizerNames.ADEMAMIX, output_dir="None")
        self.check_optim_and_kwargs(training_args, expected_cls, default_ademamix_kwargs)
|
||||||
|
|
||||||
|
def test_bnb_ademamix8bit(self):
    # Stand a mock in for `bitsandbytes` so the optimizer selection for
    # `OptimizerNames.ADEMAMIX_8BIT` can be checked against the mocked class.
    bnb_stub = Mock()
    expected_cls = bnb_stub.optim.AdEMAMix
    stubbed_modules = {
        "bitsandbytes": bnb_stub,
        "bitsandbytes.optim": bnb_stub.optim,
        "bitsandbytes.optim.AdEMAMix": expected_cls,
    }
    with patch.dict("sys.modules", stubbed_modules):
        training_args = TrainingArguments(optim=OptimizerNames.ADEMAMIX_8BIT, output_dir="None")
        self.check_optim_and_kwargs(training_args, expected_cls, default_ademamix_kwargs)
|
||||||
|
|
||||||
|
def test_bnb_paged_ademamix(self):
    # Stand a mock in for `bitsandbytes` so the optimizer selection for
    # `OptimizerNames.PAGED_ADEMAMIX` can be checked against the mocked class.
    bnb_stub = Mock()
    expected_cls = bnb_stub.optim.AdEMAMix
    stubbed_modules = {
        "bitsandbytes": bnb_stub,
        "bitsandbytes.optim": bnb_stub.optim,
        "bitsandbytes.optim.AdEMAMix": expected_cls,
    }
    with patch.dict("sys.modules", stubbed_modules):
        training_args = TrainingArguments(optim=OptimizerNames.PAGED_ADEMAMIX, output_dir="None")
        self.check_optim_and_kwargs(training_args, expected_cls, default_ademamix_kwargs)
|
||||||
|
|
||||||
|
def test_bnb_paged_ademamix8bit(self):
    # Stand a mock in for `bitsandbytes` so the optimizer selection for
    # `OptimizerNames.PAGED_ADEMAMIX_8BIT` can be checked against the mocked class.
    bnb_stub = Mock()
    expected_cls = bnb_stub.optim.AdEMAMix
    stubbed_modules = {
        "bitsandbytes": bnb_stub,
        "bitsandbytes.optim": bnb_stub.optim,
        "bitsandbytes.optim.AdEMAMix": expected_cls,
    }
    with patch.dict("sys.modules", stubbed_modules):
        training_args = TrainingArguments(optim=OptimizerNames.PAGED_ADEMAMIX_8BIT, output_dir="None")
        self.check_optim_and_kwargs(training_args, expected_cls, default_ademamix_kwargs)
|
||||||
|
|
||||||
def test_bnb_lion(self):
|
def test_bnb_lion(self):
|
||||||
mock = Mock()
|
mock = Mock()
|
||||||
modules = {
|
modules = {
|
||||||
@@ -4503,6 +4632,42 @@ class TrainerOptimizerChoiceTest(unittest.TestCase):
|
|||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
Trainer.get_optimizer_cls_and_kwargs(args)
|
Trainer.get_optimizer_cls_and_kwargs(args)
|
||||||
|
|
||||||
|
def test_bnb_ademamix_no_bnb(self):
    training_args = TrainingArguments(optim=OptimizerNames.ADEMAMIX, output_dir="None")

    # Pretend that bnb does not exist, even if installed. Mapping the module
    # to None in `sys.modules` makes importing it fail regardless of whether
    # `bitsandbytes` is installed.
    hidden_bnb = {"bitsandbytes.optim": None}
    with patch.dict("sys.modules", hidden_bnb), self.assertRaises(ValueError):
        Trainer.get_optimizer_cls_and_kwargs(training_args)
|
||||||
|
|
||||||
|
def test_bnb_ademamix8bit_no_bnb(self):
    training_args = TrainingArguments(optim=OptimizerNames.ADEMAMIX_8BIT, output_dir="None")

    # Pretend that bnb does not exist, even if installed. Mapping the module
    # to None in `sys.modules` makes importing it fail regardless of whether
    # `bitsandbytes` is installed.
    hidden_bnb = {"bitsandbytes.optim": None}
    with patch.dict("sys.modules", hidden_bnb), self.assertRaises(ValueError):
        Trainer.get_optimizer_cls_and_kwargs(training_args)
|
||||||
|
|
||||||
|
def test_bnb_paged_ademamix_no_bnb(self):
    training_args = TrainingArguments(optim=OptimizerNames.PAGED_ADEMAMIX, output_dir="None")

    # Pretend that bnb does not exist, even if installed. Mapping the module
    # to None in `sys.modules` makes importing it fail regardless of whether
    # `bitsandbytes` is installed.
    hidden_bnb = {"bitsandbytes.optim": None}
    with patch.dict("sys.modules", hidden_bnb), self.assertRaises(ValueError):
        Trainer.get_optimizer_cls_and_kwargs(training_args)
|
||||||
|
|
||||||
|
def test_bnb_paged_ademamix8bit_no_bnb(self):
    training_args = TrainingArguments(optim=OptimizerNames.PAGED_ADEMAMIX_8BIT, output_dir="None")

    # Pretend that bnb does not exist, even if installed. Mapping the module
    # to None in `sys.modules` makes importing it fail regardless of whether
    # `bitsandbytes` is installed.
    hidden_bnb = {"bitsandbytes.optim": None}
    with patch.dict("sys.modules", hidden_bnb), self.assertRaises(ValueError):
        Trainer.get_optimizer_cls_and_kwargs(training_args)
|
||||||
|
|
||||||
def test_bnb_paged_lion_no_bnb(self):
|
def test_bnb_paged_lion_no_bnb(self):
|
||||||
args = TrainingArguments(optim=OptimizerNames.PAGED_LION, output_dir="None")
|
args = TrainingArguments(optim=OptimizerNames.PAGED_LION, output_dir="None")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user