Add AdEMAMix optimizer (#33682)

* Add AdEMAMix optimizer

* Fix test

* Update tests/trainer/test_trainer.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
Matthew Douglas
2024-09-25 13:07:21 -04:00
committed by GitHub
parent 61e98cb957
commit 196d35ccfc
3 changed files with 201 additions and 1 deletions

View File

@@ -15,6 +15,7 @@
import dataclasses
import gc
import importlib
import json
import math
import os
@@ -32,6 +33,7 @@ from unittest.mock import Mock, patch
import numpy as np
from huggingface_hub import HfFolder, ModelCard, create_branch, delete_repo, list_repo_commits, list_repo_files
from packaging import version
from parameterized import parameterized
from requests.exceptions import HTTPError
@@ -1091,6 +1093,40 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
# Check that it trains without errors
trainer.train()
@require_bitsandbytes
def test_ademamix_bnb(self):
config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
tiny_gpt2 = GPT2LMHeadModel(config)
x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)
with tempfile.TemporaryDirectory() as tmpdir:
# Trainer without inf/nan filter
args = TrainingArguments(
tmpdir, learning_rate=1e-9, logging_steps=5, logging_nan_inf_filter=False, optim="ademamix"
)
trainer = Trainer(tiny_gpt2, args, train_dataset=train_dataset)
# Check that it trains without errors
trainer.train()
@require_bitsandbytes
def test_ademamix_bnb_8bit(self):
config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
tiny_gpt2 = GPT2LMHeadModel(config)
x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)
with tempfile.TemporaryDirectory() as tmpdir:
# Trainer without inf/nan filter
args = TrainingArguments(
tmpdir, learning_rate=1e-9, logging_steps=5, logging_nan_inf_filter=False, optim="ademamix_8bit"
)
trainer = Trainer(tiny_gpt2, args, train_dataset=train_dataset)
# Check that it trains without errors
trainer.train()
@require_bitsandbytes
def test_rmsprop_bnb_8bit(self):
config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
@@ -4187,6 +4223,13 @@ if is_torch_available():
"lr": TrainingArguments.learning_rate,
}
default_ademamix_kwargs = {
"betas": (TrainingArguments.adam_beta1, TrainingArguments.adam_beta2, 0.9999),
"alpha": 5.0,
"eps": TrainingArguments.adam_epsilon,
"lr": TrainingArguments.learning_rate,
}
default_anyprecision_kwargs = {
"use_kahan_summation": False,
"momentum_dtype": torch.float32,
@@ -4291,6 +4334,36 @@ if is_torch_available():
)
)
if version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.44.0"):
optim_test_params.append(
(
TrainingArguments(optim=OptimizerNames.ADEMAMIX, output_dir="None"),
bnb.optim.AdEMAMix,
default_ademamix_kwargs,
)
)
optim_test_params.append(
(
TrainingArguments(optim=OptimizerNames.ADEMAMIX_8BIT, output_dir="None"),
bnb.optim.AdEMAMix,
default_ademamix_kwargs,
)
)
optim_test_params.append(
(
TrainingArguments(optim=OptimizerNames.PAGED_ADEMAMIX_8BIT, output_dir="None"),
bnb.optim.AdEMAMix,
default_ademamix_kwargs,
)
)
optim_test_params.append(
(
TrainingArguments(optim=OptimizerNames.PAGED_ADEMAMIX, output_dir="None"),
bnb.optim.AdEMAMix,
default_ademamix_kwargs,
)
)
if is_torchdistx_available():
import torchdistx
@@ -4420,6 +4493,62 @@ class TrainerOptimizerChoiceTest(unittest.TestCase):
default_adam_kwargs,
)
def test_bnb_ademamix(self):
mock = Mock()
modules = {
"bitsandbytes": mock,
"bitsandbytes.optim": mock.optim,
"bitsandbytes.optim.AdEMAMix": mock.optim.AdEMAMix,
}
with patch.dict("sys.modules", modules):
self.check_optim_and_kwargs(
TrainingArguments(optim=OptimizerNames.ADEMAMIX, output_dir="None"),
mock.optim.AdEMAMix,
default_ademamix_kwargs,
)
def test_bnb_ademamix8bit(self):
mock = Mock()
modules = {
"bitsandbytes": mock,
"bitsandbytes.optim": mock.optim,
"bitsandbytes.optim.AdEMAMix": mock.optim.AdEMAMix,
}
with patch.dict("sys.modules", modules):
self.check_optim_and_kwargs(
TrainingArguments(optim=OptimizerNames.ADEMAMIX_8BIT, output_dir="None"),
mock.optim.AdEMAMix,
default_ademamix_kwargs,
)
def test_bnb_paged_ademamix(self):
mock = Mock()
modules = {
"bitsandbytes": mock,
"bitsandbytes.optim": mock.optim,
"bitsandbytes.optim.AdEMAMix": mock.optim.AdEMAMix,
}
with patch.dict("sys.modules", modules):
self.check_optim_and_kwargs(
TrainingArguments(optim=OptimizerNames.PAGED_ADEMAMIX, output_dir="None"),
mock.optim.AdEMAMix,
default_ademamix_kwargs,
)
def test_bnb_paged_ademamix8bit(self):
mock = Mock()
modules = {
"bitsandbytes": mock,
"bitsandbytes.optim": mock.optim,
"bitsandbytes.optim.AdEMAMix": mock.optim.AdEMAMix,
}
with patch.dict("sys.modules", modules):
self.check_optim_and_kwargs(
TrainingArguments(optim=OptimizerNames.PAGED_ADEMAMIX_8BIT, output_dir="None"),
mock.optim.AdEMAMix,
default_ademamix_kwargs,
)
def test_bnb_lion(self):
mock = Mock()
modules = {
@@ -4503,6 +4632,42 @@ class TrainerOptimizerChoiceTest(unittest.TestCase):
with self.assertRaises(ValueError):
Trainer.get_optimizer_cls_and_kwargs(args)
def test_bnb_ademamix_no_bnb(self):
args = TrainingArguments(optim=OptimizerNames.ADEMAMIX, output_dir="None")
# Pretend that bnb does not exist, even if installed. By setting bnb to None, importing
# bnb will fail even if `bitsandbytes` is installed.
with patch.dict("sys.modules", {"bitsandbytes.optim": None}):
with self.assertRaises(ValueError):
Trainer.get_optimizer_cls_and_kwargs(args)
def test_bnb_ademamix8bit_no_bnb(self):
args = TrainingArguments(optim=OptimizerNames.ADEMAMIX_8BIT, output_dir="None")
# Pretend that bnb does not exist, even if installed. By setting bnb to None, importing
# bnb will fail even if `bitsandbytes` is installed.
with patch.dict("sys.modules", {"bitsandbytes.optim": None}):
with self.assertRaises(ValueError):
Trainer.get_optimizer_cls_and_kwargs(args)
def test_bnb_paged_ademamix_no_bnb(self):
args = TrainingArguments(optim=OptimizerNames.PAGED_ADEMAMIX, output_dir="None")
# Pretend that bnb does not exist, even if installed. By setting bnb to None, importing
# bnb will fail even if `bitsandbytes` is installed.
with patch.dict("sys.modules", {"bitsandbytes.optim": None}):
with self.assertRaises(ValueError):
Trainer.get_optimizer_cls_and_kwargs(args)
def test_bnb_paged_ademamix8bit_no_bnb(self):
args = TrainingArguments(optim=OptimizerNames.PAGED_ADEMAMIX_8BIT, output_dir="None")
# Pretend that bnb does not exist, even if installed. By setting bnb to None, importing
# bnb will fail even if `bitsandbytes` is installed.
with patch.dict("sys.modules", {"bitsandbytes.optim": None}):
with self.assertRaises(ValueError):
Trainer.get_optimizer_cls_and_kwargs(args)
def test_bnb_paged_lion_no_bnb(self):
args = TrainingArguments(optim=OptimizerNames.PAGED_LION, output_dir="None")