Add support for GrokAdamW optimizer (#32521)

* add grokadamw

* reformat

* code review feedback, unit test

* reformat

* reformat
This commit is contained in:
Eric Hartford
2024-08-13 08:20:28 -04:00
committed by GitHub
parent b5016d5de7
commit 481e15604a
7 changed files with 107 additions and 0 deletions

View File

@@ -432,6 +432,57 @@ trainer = trl.SFTTrainer(
trainer.train()
```
## GrokAdamW optimizer
The GrokAdamW optimizer is designed to enhance training performance and stability, particularly for models that benefit from grokking signal functions. To use GrokAdamW, first install the optimizer package with `pip install grokadamw`.
<Tip>
GrokAdamW is particularly useful for models that require advanced optimization techniques to achieve better performance and stability.
</Tip>
Below is a simple script to demonstrate how to fine-tune [google/gemma-2b](https://huggingface.co/google/gemma-2b) on the IMDB dataset using the GrokAdamW optimizer:
```python
import torch
import datasets
from transformers import TrainingArguments, AutoTokenizer, AutoModelForCausalLM, Trainer
# Load the IMDB dataset
train_dataset = datasets.load_dataset('imdb', split='train')
# Define the training arguments
args = TrainingArguments(
output_dir="./test-grokadamw",
max_steps=1000,
per_device_train_batch_size=4,
optim="grokadamw",
logging_strategy="steps",
logging_steps=1,
learning_rate=2e-5,
save_strategy="no",
run_name="grokadamw-imdb",
)
# Load the model and tokenizer
model_id = "google/gemma-2b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True).to(0)
# Initialize the Trainer
trainer = Trainer(
model=model,
args=args,
train_dataset=train_dataset,
)
# Train the model
trainer.train()
```
This script demonstrates how to fine-tune the `google/gemma-2b` model on the IMDB dataset using the GrokAdamW optimizer. The `TrainingArguments` are configured to use GrokAdamW, and the dataset is passed to the `Trainer` for training.
## Accelerate and Trainer
The [`Trainer`] class is powered by [Accelerate](https://hf.co/docs/accelerate), a library for easily training PyTorch models in distributed environments with support for integrations such as [FullyShardedDataParallel (FSDP)](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/) and [DeepSpeed](https://www.deepspeed.ai/).

View File

@@ -76,6 +76,7 @@ from .utils import (
is_g2p_en_available,
is_galore_torch_available,
is_gguf_available,
is_grokadamw_available,
is_ipex_available,
is_jieba_available,
is_jinja_available,
@@ -358,6 +359,13 @@ def require_lomo(test_case):
return unittest.skipUnless(is_lomo_available(), "test requires LOMO")(test_case)
def require_grokadamw(test_case):
"""
Decorator marking a test that requires GrokAdamW. These tests are skipped when GrokAdamW isn't installed.
"""
return unittest.skipUnless(is_grokadamw_available(), "test requires GrokAdamW")(test_case)
def require_cv2(test_case):
"""
Decorator marking a test that requires OpenCV.

View File

@@ -153,6 +153,7 @@ from .utils import (
is_bitsandbytes_available,
is_datasets_available,
is_galore_torch_available,
is_grokadamw_available,
is_in_notebook,
is_ipex_available,
is_lomo_available,
@@ -1442,6 +1443,23 @@ class Trainer:
optimizer_cls = Lomo
optimizer_kwargs.update({"model": model})
elif args.optim == OptimizerNames.GROKADAMW:
if not is_grokadamw_available():
raise ValueError("Please install grokadamw with `pip install grokadamw`")
from grokadamw import GrokAdamW
optimizer_cls = GrokAdamW
optimizer_kwargs.update(
{
"alpha_init": float(optim_args.get("alpha_init", 0.98)),
"lamb": float(optim_args.get("lamb", 2.0)),
"gamma": float(optim_args.get("gamma", 0.1)),
"grokking_signal_decay_rate": float(optim_args.get("grokking_signal_decay_rate", 0.1)),
"gradient_clipping": float(optim_args.get("gradient_clipping", 1.0)),
}
)
else:
raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
return optimizer_cls, optimizer_kwargs

View File

@@ -175,6 +175,7 @@ class OptimizerNames(ExplicitEnum):
GALORE_ADAFACTOR_LAYERWISE = "galore_adafactor_layerwise"
LOMO = "lomo"
ADALOMO = "adalomo"
GROKADAMW = "grokadamw"
# Sometimes users will pass in a `str` repr of a dict in the CLI

View File

@@ -137,6 +137,7 @@ from .import_utils import (
is_g2p_en_available,
is_galore_torch_available,
is_gguf_available,
is_grokadamw_available,
is_hqq_available,
is_in_notebook,
is_ipex_available,

View File

@@ -101,6 +101,7 @@ _eetq_available = _is_package_available("eetq")
_fbgemm_gpu_available = _is_package_available("fbgemm_gpu")
_galore_torch_available = _is_package_available("galore_torch")
_lomo_available = _is_package_available("lomo_optim")
_grokadamw_available = _is_package_available("grokadamw")
# `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
_bs4_available = importlib.util.find_spec("bs4") is not None
_coloredlogs_available = _is_package_available("coloredlogs")
@@ -353,6 +354,10 @@ def is_lomo_available():
return _lomo_available
def is_grokadamw_available():
return _grokadamw_available
def is_pyctcdecode_available():
return _pyctcdecode_available

View File

@@ -62,6 +62,7 @@ from transformers.testing_utils import (
require_bitsandbytes,
require_deepspeed,
require_galore_torch,
require_grokadamw,
require_intel_extension_for_pytorch,
require_lomo,
require_optuna,
@@ -1366,6 +1367,28 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
# Check this works
_ = trainer.train()
@require_grokadamw
@require_torch_gpu
def test_grokadamw():
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)
with tempfile.TemporaryDirectory() as tmpdir:
# Trainer without inf/nan filter
args = TrainingArguments(
tmpdir,
learning_rate=2e-5,
logging_steps=5,
optim="grokadamw",
max_steps=20,
)
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
# Check this works
_ = trainer.train()
def test_galore_matched_modules(self):
regex_patterns = [r".*.attn.*", r".*.mlp.*"]