diff --git a/docs/source/en/trainer.md b/docs/source/en/trainer.md index 916ae6428e..37d8baf3d7 100644 --- a/docs/source/en/trainer.md +++ b/docs/source/en/trainer.md @@ -432,6 +432,57 @@ trainer = trl.SFTTrainer( trainer.train() ``` +## GrokAdamW optimizer + +The GrokAdamW optimizer is designed to enhance training performance and stability, particularly for models that benefit from grokking signal functions. To use GrokAdamW, first install the optimizer package with `pip install grokadamw`. + + + +GrokAdamW is particularly useful for models that require advanced optimization techniques to achieve better performance and stability. + + + +Below is a simple script to demonstrate how to fine-tune [google/gemma-2b](https://huggingface.co/google/gemma-2b) on the IMDB dataset using the GrokAdamW optimizer: + +```python +import torch +import datasets +from transformers import TrainingArguments, AutoTokenizer, AutoModelForCausalLM, Trainer + +# Load the IMDB dataset +train_dataset = datasets.load_dataset('imdb', split='train') + +# Define the training arguments +args = TrainingArguments( + output_dir="./test-grokadamw", + max_steps=1000, + per_device_train_batch_size=4, + optim="grokadamw", + logging_strategy="steps", + logging_steps=1, + learning_rate=2e-5, + save_strategy="no", + run_name="grokadamw-imdb", +) + +# Load the model and tokenizer +model_id = "google/gemma-2b" +tokenizer = AutoTokenizer.from_pretrained(model_id) +model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True).to(0) + +# Initialize the Trainer +trainer = Trainer( + model=model, + args=args, + train_dataset=train_dataset, +) + +# Train the model +trainer.train() +``` + +This script demonstrates how to fine-tune the `google/gemma-2b` model on the IMDB dataset using the GrokAdamW optimizer. The `TrainingArguments` are configured to use GrokAdamW, and the dataset is passed to the `Trainer` for training. + ## Accelerate and Trainer The [`Trainer`] class is powered by [Accelerate](https://hf.co/docs/accelerate), a library for easily training PyTorch models in distributed environments with support for integrations such as [FullyShardedDataParallel (FSDP)](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/) and [DeepSpeed](https://www.deepspeed.ai/). diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 608e278ecf..6b04ed7426 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -76,6 +76,7 @@ from .utils import ( is_g2p_en_available, is_galore_torch_available, is_gguf_available, + is_grokadamw_available, is_ipex_available, is_jieba_available, is_jinja_available, @@ -358,6 +359,13 @@ def require_lomo(test_case): return unittest.skipUnless(is_lomo_available(), "test requires LOMO")(test_case) +def require_grokadamw(test_case): + """ + Decorator marking a test that requires GrokAdamW. These tests are skipped when GrokAdamW isn't installed. + """ + return unittest.skipUnless(is_grokadamw_available(), "test requires GrokAdamW")(test_case) + + def require_cv2(test_case): """ Decorator marking a test that requires OpenCV. diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index e74b463e8d..094f058469 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -153,6 +153,7 @@ from .utils import ( is_bitsandbytes_available, is_datasets_available, is_galore_torch_available, + is_grokadamw_available, is_in_notebook, is_ipex_available, is_lomo_available, @@ -1442,6 +1443,23 @@ class Trainer: optimizer_cls = Lomo optimizer_kwargs.update({"model": model}) + elif args.optim == OptimizerNames.GROKADAMW: + if not is_grokadamw_available(): + raise ValueError("Please install grokadamw with `pip install grokadamw`") + + from grokadamw import GrokAdamW + + optimizer_cls = GrokAdamW + optimizer_kwargs.update( + { + "alpha_init": float(optim_args.get("alpha_init", 0.98)), + "lamb": float(optim_args.get("lamb", 2.0)), + "gamma": float(optim_args.get("gamma", 0.1)), + "grokking_signal_decay_rate": float(optim_args.get("grokking_signal_decay_rate", 0.1)), + "gradient_clipping": float(optim_args.get("gradient_clipping", 1.0)), + } + ) + else: raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}") return optimizer_cls, optimizer_kwargs diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 57605fd945..1058b8356b 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -175,6 +175,7 @@ class OptimizerNames(ExplicitEnum): GALORE_ADAFACTOR_LAYERWISE = "galore_adafactor_layerwise" LOMO = "lomo" ADALOMO = "adalomo" + GROKADAMW = "grokadamw" # Sometimes users will pass in a `str` repr of a dict in the CLI diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index efe473a6cd..a8aa670c07 100755 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -137,6 +137,7 @@ from .import_utils import ( is_g2p_en_available, is_galore_torch_available, is_gguf_available, + is_grokadamw_available, is_hqq_available, is_in_notebook, is_ipex_available, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index b9936a038a..97d3a5501b 100755 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -101,6 +101,7 @@ _eetq_available = _is_package_available("eetq") _fbgemm_gpu_available = _is_package_available("fbgemm_gpu") _galore_torch_available = _is_package_available("galore_torch") _lomo_available = _is_package_available("lomo_optim") +_grokadamw_available = _is_package_available("grokadamw") # `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed. _bs4_available = importlib.util.find_spec("bs4") is not None _coloredlogs_available = _is_package_available("coloredlogs") @@ -353,6 +354,10 @@ def is_lomo_available(): return _lomo_available +def is_grokadamw_available(): + return _grokadamw_available + + def is_pyctcdecode_available(): return _pyctcdecode_available diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 7378a597c3..ca133a277c 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -62,6 +62,7 @@ from transformers.testing_utils import ( require_bitsandbytes, require_deepspeed, require_galore_torch, + require_grokadamw, require_intel_extension_for_pytorch, require_lomo, require_optuna, @@ -1366,6 +1367,28 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): # Check this works _ = trainer.train() + @require_grokadamw + @require_torch_gpu + def test_grokadamw(): + config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4) + tiny_llama = LlamaForCausalLM(config) + x = torch.randint(0, 100, (128,)) + train_dataset = RepeatDataset(x) + + with tempfile.TemporaryDirectory() as tmpdir: + # Trainer without inf/nan filter + args = TrainingArguments( + tmpdir, + learning_rate=2e-5, + logging_steps=5, + optim="grokadamw", + max_steps=20, + ) + trainer = Trainer(tiny_llama, args, train_dataset=train_dataset) + + # Check this works + _ = trainer.train() + def test_galore_matched_modules(self): regex_patterns = [r".*.attn.*", r".*.mlp.*"]