Add support for GrokAdamW optimizer (#32521)
* add grokadamw * reformat * code review feedback, unit test * reformat * reformat
This commit is contained in:
@@ -432,6 +432,57 @@ trainer = trl.SFTTrainer(
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
## GrokAdamW optimizer
|
||||
|
||||
The GrokAdamW optimizer is designed to enhance training performance and stability, particularly for models that benefit from grokking signal functions. To use GrokAdamW, first install the optimizer package with `pip install grokadamw`.
|
||||
|
||||
<Tip>
|
||||
|
||||
GrokAdamW is particularly useful for models that require advanced optimization techniques to achieve better performance and stability.
|
||||
|
||||
</Tip>
|
||||
|
||||
Below is a simple script to demonstrate how to fine-tune [google/gemma-2b](https://huggingface.co/google/gemma-2b) on the IMDB dataset using the GrokAdamW optimizer:
|
||||
|
||||
```python
|
||||
import torch
|
||||
import datasets
|
||||
from transformers import TrainingArguments, AutoTokenizer, AutoModelForCausalLM, Trainer
|
||||
|
||||
# Load the IMDB dataset
|
||||
train_dataset = datasets.load_dataset('imdb', split='train')
|
||||
|
||||
# Define the training arguments
|
||||
args = TrainingArguments(
|
||||
output_dir="./test-grokadamw",
|
||||
max_steps=1000,
|
||||
per_device_train_batch_size=4,
|
||||
optim="grokadamw",
|
||||
logging_strategy="steps",
|
||||
logging_steps=1,
|
||||
learning_rate=2e-5,
|
||||
save_strategy="no",
|
||||
run_name="grokadamw-imdb",
|
||||
)
|
||||
|
||||
# Load the model and tokenizer
|
||||
model_id = "google/gemma-2b"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True).to(0)
|
||||
|
||||
# Initialize the Trainer
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=args,
|
||||
train_dataset=train_dataset,
|
||||
)
|
||||
|
||||
# Train the model
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
This script demonstrates how to fine-tune the `google/gemma-2b` model on the IMDB dataset using the GrokAdamW optimizer. The `TrainingArguments` are configured to use GrokAdamW, and the dataset is passed to the `Trainer` for training.
|
||||
|
||||
## Accelerate and Trainer
|
||||
|
||||
The [`Trainer`] class is powered by [Accelerate](https://hf.co/docs/accelerate), a library for easily training PyTorch models in distributed environments with support for integrations such as [FullyShardedDataParallel (FSDP)](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/) and [DeepSpeed](https://www.deepspeed.ai/).
|
||||
|
||||
@@ -76,6 +76,7 @@ from .utils import (
|
||||
is_g2p_en_available,
|
||||
is_galore_torch_available,
|
||||
is_gguf_available,
|
||||
is_grokadamw_available,
|
||||
is_ipex_available,
|
||||
is_jieba_available,
|
||||
is_jinja_available,
|
||||
@@ -358,6 +359,13 @@ def require_lomo(test_case):
|
||||
return unittest.skipUnless(is_lomo_available(), "test requires LOMO")(test_case)
|
||||
|
||||
|
||||
def require_grokadamw(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires GrokAdamW. These tests are skipped when GrokAdamW isn't installed.
|
||||
"""
|
||||
return unittest.skipUnless(is_grokadamw_available(), "test requires GrokAdamW")(test_case)
|
||||
|
||||
|
||||
def require_cv2(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires OpenCV.
|
||||
|
||||
@@ -153,6 +153,7 @@ from .utils import (
|
||||
is_bitsandbytes_available,
|
||||
is_datasets_available,
|
||||
is_galore_torch_available,
|
||||
is_grokadamw_available,
|
||||
is_in_notebook,
|
||||
is_ipex_available,
|
||||
is_lomo_available,
|
||||
@@ -1442,6 +1443,23 @@ class Trainer:
|
||||
optimizer_cls = Lomo
|
||||
|
||||
optimizer_kwargs.update({"model": model})
|
||||
elif args.optim == OptimizerNames.GROKADAMW:
|
||||
if not is_grokadamw_available():
|
||||
raise ValueError("Please install grokadamw with `pip install grokadamw`")
|
||||
|
||||
from grokadamw import GrokAdamW
|
||||
|
||||
optimizer_cls = GrokAdamW
|
||||
optimizer_kwargs.update(
|
||||
{
|
||||
"alpha_init": float(optim_args.get("alpha_init", 0.98)),
|
||||
"lamb": float(optim_args.get("lamb", 2.0)),
|
||||
"gamma": float(optim_args.get("gamma", 0.1)),
|
||||
"grokking_signal_decay_rate": float(optim_args.get("grokking_signal_decay_rate", 0.1)),
|
||||
"gradient_clipping": float(optim_args.get("gradient_clipping", 1.0)),
|
||||
}
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
|
||||
return optimizer_cls, optimizer_kwargs
|
||||
|
||||
@@ -175,6 +175,7 @@ class OptimizerNames(ExplicitEnum):
|
||||
GALORE_ADAFACTOR_LAYERWISE = "galore_adafactor_layerwise"
|
||||
LOMO = "lomo"
|
||||
ADALOMO = "adalomo"
|
||||
GROKADAMW = "grokadamw"
|
||||
|
||||
|
||||
# Sometimes users will pass in a `str` repr of a dict in the CLI
|
||||
|
||||
@@ -137,6 +137,7 @@ from .import_utils import (
|
||||
is_g2p_en_available,
|
||||
is_galore_torch_available,
|
||||
is_gguf_available,
|
||||
is_grokadamw_available,
|
||||
is_hqq_available,
|
||||
is_in_notebook,
|
||||
is_ipex_available,
|
||||
|
||||
@@ -101,6 +101,7 @@ _eetq_available = _is_package_available("eetq")
|
||||
_fbgemm_gpu_available = _is_package_available("fbgemm_gpu")
|
||||
_galore_torch_available = _is_package_available("galore_torch")
|
||||
_lomo_available = _is_package_available("lomo_optim")
|
||||
_grokadamw_available = _is_package_available("grokadamw")
|
||||
# `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
|
||||
_bs4_available = importlib.util.find_spec("bs4") is not None
|
||||
_coloredlogs_available = _is_package_available("coloredlogs")
|
||||
@@ -353,6 +354,10 @@ def is_lomo_available():
|
||||
return _lomo_available
|
||||
|
||||
|
||||
def is_grokadamw_available():
|
||||
return _grokadamw_available
|
||||
|
||||
|
||||
def is_pyctcdecode_available():
|
||||
return _pyctcdecode_available
|
||||
|
||||
|
||||
@@ -62,6 +62,7 @@ from transformers.testing_utils import (
|
||||
require_bitsandbytes,
|
||||
require_deepspeed,
|
||||
require_galore_torch,
|
||||
require_grokadamw,
|
||||
require_intel_extension_for_pytorch,
|
||||
require_lomo,
|
||||
require_optuna,
|
||||
@@ -1366,6 +1367,28 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
# Check this works
|
||||
_ = trainer.train()
|
||||
|
||||
@require_grokadamw
|
||||
@require_torch_gpu
|
||||
def test_grokadamw():
|
||||
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||
tiny_llama = LlamaForCausalLM(config)
|
||||
x = torch.randint(0, 100, (128,))
|
||||
train_dataset = RepeatDataset(x)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Trainer without inf/nan filter
|
||||
args = TrainingArguments(
|
||||
tmpdir,
|
||||
learning_rate=2e-5,
|
||||
logging_steps=5,
|
||||
optim="grokadamw",
|
||||
max_steps=20,
|
||||
)
|
||||
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
|
||||
|
||||
# Check this works
|
||||
_ = trainer.train()
|
||||
|
||||
def test_galore_matched_modules(self):
|
||||
regex_patterns = [r".*.attn.*", r".*.mlp.*"]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user