Add support for GrokAdamW optimizer (#32521)

* add grokadamw
* reformat
* code review feedback, unit test
* reformat
* reformat
@@ -432,6 +432,57 @@ trainer = trl.SFTTrainer(
 trainer.train()
 ```
 
+## GrokAdamW optimizer
+
+The GrokAdamW optimizer is designed to enhance training performance and stability, particularly for models that benefit from grokking signal functions. To use GrokAdamW, first install the optimizer package with `pip install grokadamw`.
+
+<Tip>
+
+GrokAdamW is particularly useful for models that require advanced optimization techniques to achieve better performance and stability.
+
+</Tip>
+
+Below is a simple script to demonstrate how to fine-tune [google/gemma-2b](https://huggingface.co/google/gemma-2b) on the IMDB dataset using the GrokAdamW optimizer:
+
+```python
+import torch
+import datasets
+from transformers import TrainingArguments, AutoTokenizer, AutoModelForCausalLM, Trainer
+
+# Load the IMDB dataset
+train_dataset = datasets.load_dataset('imdb', split='train')
+
+# Define the training arguments
+args = TrainingArguments(
+    output_dir="./test-grokadamw",
+    max_steps=1000,
+    per_device_train_batch_size=4,
+    optim="grokadamw",
+    logging_strategy="steps",
+    logging_steps=1,
+    learning_rate=2e-5,
+    save_strategy="no",
+    run_name="grokadamw-imdb",
+)
+
+# Load the model and tokenizer
+model_id = "google/gemma-2b"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True).to(0)
+
+# Initialize the Trainer
+trainer = Trainer(
+    model=model,
+    args=args,
+    train_dataset=train_dataset,
+)
+
+# Train the model
+trainer.train()
+```
+
+This script demonstrates how to fine-tune the `google/gemma-2b` model on the IMDB dataset using the GrokAdamW optimizer. The `TrainingArguments` are configured to use GrokAdamW, and the dataset is passed to the `Trainer` for training.
+
 ## Accelerate and Trainer
 
The [`Trainer`] class is powered by [Accelerate](https://hf.co/docs/accelerate), a library for easily training PyTorch models in distributed environments with support for integrations such as [FullyShardedDataParallel (FSDP)](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/) and [DeepSpeed](https://www.deepspeed.ai/).
@@ -76,6 +76,7 @@ from .utils import (
     is_g2p_en_available,
     is_galore_torch_available,
     is_gguf_available,
+    is_grokadamw_available,
     is_ipex_available,
     is_jieba_available,
     is_jinja_available,
@@ -358,6 +359,13 @@ def require_lomo(test_case):
     return unittest.skipUnless(is_lomo_available(), "test requires LOMO")(test_case)
 
 
+def require_grokadamw(test_case):
+    """
+    Decorator marking a test that requires GrokAdamW. These tests are skipped when GrokAdamW isn't installed.
+    """
+    return unittest.skipUnless(is_grokadamw_available(), "test requires GrokAdamW")(test_case)
+
+
 def require_cv2(test_case):
     """
Decorator marking a test that requires OpenCV.
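
The `require_grokadamw` decorator in the hunk above follows the standard `unittest.skipUnless` pattern. The following is a minimal standalone sketch of that pattern, not the library's code: `is_fake_optimizer_available`, `require_fake_optimizer`, and the package name are hypothetical and chosen so that the dependency is missing and the skip path is exercised.

```python
import importlib.util
import unittest


def is_fake_optimizer_available() -> bool:
    # Hypothetical check: this package name is made up for illustration,
    # so the lookup is expected to find nothing.
    return importlib.util.find_spec("fake_optimizer_pkg_for_illustration") is not None


def require_fake_optimizer(test_case):
    """Skip the decorated test unless the (hypothetical) package is installed."""
    return unittest.skipUnless(is_fake_optimizer_available(), "test requires fake_optimizer")(test_case)


class ExampleTest(unittest.TestCase):
    @require_fake_optimizer
    def test_needs_optional_dependency(self):
        self.fail("not reached when the dependency is missing")

    def test_always_runs(self):
        self.assertTrue(True)


def run_example() -> unittest.TestResult:
    # Run the two tests and return the collected result object.
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(ExampleTest)
    result = unittest.TestResult()
    suite.run(result)
    return result
```

Because the availability check runs when the decorator is applied, the skip decision is made at import time of the test module rather than inside the test body.
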
@@ -153,6 +153,7 @@ from .utils import (
     is_bitsandbytes_available,
     is_datasets_available,
     is_galore_torch_available,
+    is_grokadamw_available,
     is_in_notebook,
     is_ipex_available,
     is_lomo_available,
@@ -1442,6 +1443,23 @@ class Trainer:
                 optimizer_cls = Lomo
 
             optimizer_kwargs.update({"model": model})
+        elif args.optim == OptimizerNames.GROKADAMW:
+            if not is_grokadamw_available():
+                raise ValueError("Please install grokadamw with `pip install grokadamw`")
+
+            from grokadamw import GrokAdamW
+
+            optimizer_cls = GrokAdamW
+            optimizer_kwargs.update(
+                {
+                    "alpha_init": float(optim_args.get("alpha_init", 0.98)),
+                    "lamb": float(optim_args.get("lamb", 2.0)),
+                    "gamma": float(optim_args.get("gamma", 0.1)),
+                    "grokking_signal_decay_rate": float(optim_args.get("grokking_signal_decay_rate", 0.1)),
+                    "gradient_clipping": float(optim_args.get("gradient_clipping", 1.0)),
+                }
+            )
+
         else:
             raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
return optimizer_cls, optimizer_kwargs
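
The GrokAdamW branch above reads five hyperparameters from `optim_args` and falls back to a default when a key is absent. The sketch below mirrors that lookup in plain Python. It assumes `optim_args` reaches the branch as a dict of strings parsed from a comma-separated `key=value` string (for example `optim_args="alpha_init=0.95,lamb=1.5"` passed to `TrainingArguments`); `parse_optim_args` and `grokadamw_kwargs` are hypothetical helper names used only for illustration.

```python
from typing import Dict, Optional

# Defaults copied from the GrokAdamW branch in the hunk above.
GROKADAMW_DEFAULTS = {
    "alpha_init": 0.98,
    "lamb": 2.0,
    "gamma": 0.1,
    "grokking_signal_decay_rate": 0.1,
    "gradient_clipping": 1.0,
}


def parse_optim_args(raw: Optional[str]) -> Dict[str, str]:
    """Hypothetical helper: turn "a=1,b=2" into {"a": "1", "b": "2"}."""
    parsed: Dict[str, str] = {}
    if raw:
        for mapping in raw.replace(" ", "").split(","):
            key, value = mapping.split("=")
            parsed[key] = value
    return parsed


def grokadamw_kwargs(optim_args: Dict[str, str]) -> Dict[str, float]:
    """Apply the same float(optim_args.get(name, default)) lookup as the hunk."""
    return {name: float(optim_args.get(name, default)) for name, default in GROKADAMW_DEFAULTS.items()}


# Override two values and keep the remaining three defaults.
kwargs = grokadamw_kwargs(parse_optim_args("alpha_init=0.95, lamb=1.5"))
print(kwargs)
```

Values parsed from the string are still strings at lookup time, which is why every entry is wrapped in `float(...)`.
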
@@ -175,6 +175,7 @@ class OptimizerNames(ExplicitEnum):
     GALORE_ADAFACTOR_LAYERWISE = "galore_adafactor_layerwise"
     LOMO = "lomo"
     ADALOMO = "adalomo"
+    GROKADAMW = "grokadamw"
 
 
# Sometimes users will pass in a `str` repr of a dict in the CLI
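
The new `GROKADAMW = "grokadamw"` member is what allows the string `optim="grokadamw"` to resolve to an enum value. Here is a small sketch of that value lookup, using the standard library's `enum.Enum` as a stand-in for `ExplicitEnum`. The class name `OptimizerNamesSketch`, the reduced member list, and `resolve_optimizer` are illustrative assumptions, not part of the library.

```python
from enum import Enum


class OptimizerNamesSketch(str, Enum):
    """Stand-in for OptimizerNames, reduced to three members for illustration."""

    LOMO = "lomo"
    ADALOMO = "adalomo"
    GROKADAMW = "grokadamw"


def resolve_optimizer(name: str) -> OptimizerNamesSketch:
    # Lookup by value; Enum raises ValueError for names that are not members.
    return OptimizerNamesSketch(name)


selected = resolve_optimizer("grokadamw")
```

An unknown name such as `"not_an_optimizer"` raises `ValueError` at lookup time, before any optimizer class is selected.
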
@@ -137,6 +137,7 @@ from .import_utils import (
     is_g2p_en_available,
     is_galore_torch_available,
     is_gguf_available,
+    is_grokadamw_available,
     is_hqq_available,
     is_in_notebook,
     is_ipex_available,
@@ -101,6 +101,7 @@ _eetq_available = _is_package_available("eetq")
 _fbgemm_gpu_available = _is_package_available("fbgemm_gpu")
 _galore_torch_available = _is_package_available("galore_torch")
 _lomo_available = _is_package_available("lomo_optim")
+_grokadamw_available = _is_package_available("grokadamw")
 # `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
 _bs4_available = importlib.util.find_spec("bs4") is not None
_coloredlogs_available = _is_package_available("coloredlogs")
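
The `_grokadamw_available` flag above is computed once at import time and later exposed through an accessor function. The body of `_is_package_available` is not part of this diff, so the following is only a sketch of the general pattern, assuming a plain `importlib.util.find_spec` lookup. The real helper may do more (for example, reading version metadata), and `package_available`, `_json_available`, and `is_json_available` are hypothetical names.

```python
import importlib.util


def package_available(pkg_name: str) -> bool:
    """Return True if a top-level module named `pkg_name` can be located."""
    return importlib.util.find_spec(pkg_name) is not None


# Computed once at import time, like the module-level flags in the hunk above.
# `json` is used here only because it ships with the standard library.
_json_available = package_available("json")


def is_json_available() -> bool:
    # Accessor returning the cached flag instead of repeating the lookup.
    return _json_available
```

Caching the result in a module-level variable keeps repeated availability checks cheap, at the cost of not noticing packages installed after import.
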
@@ -353,6 +354,10 @@ def is_lomo_available():
     return _lomo_available
 
 
+def is_grokadamw_available():
+    return _grokadamw_available
+
+
 def is_pyctcdecode_available():
     return _pyctcdecode_available
 
@@ -62,6 +62,7 @@ from transformers.testing_utils import (
     require_bitsandbytes,
     require_deepspeed,
     require_galore_torch,
+    require_grokadamw,
     require_intel_extension_for_pytorch,
     require_lomo,
     require_optuna,
@@ -1366,6 +1367,28 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
             # Check this works
             _ = trainer.train()
 
+    @require_grokadamw
+    @require_torch_gpu
+    def test_grokadamw(self):
+        config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
+        tiny_llama = LlamaForCausalLM(config)
+        x = torch.randint(0, 100, (128,))
+        train_dataset = RepeatDataset(x)
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            # Trainer without inf/nan filter
+            args = TrainingArguments(
+                tmpdir,
+                learning_rate=2e-5,
+                logging_steps=5,
+                optim="grokadamw",
+                max_steps=20,
+            )
+            trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
+
+            # Check this works
+            _ = trainer.train()
+
     def test_galore_matched_modules(self):
         regex_patterns = [r".*.attn.*", r".*.mlp.*"]
 