schedulefree optimizers (#30079)
* schedulefree optimizers * fix train instead of eval for optimizer * fixes and update docs * chore: lint * add tests and drop overly-verbose _32bit suffix * chore: lint * fix for docs * fix code review issues * use duck-typing to avoid per-optimizer patches * fixup style * fixup style * warn if incorrect accelerate version with schedule free Co-authored-by: Aman Gupta Karmani <aman@tmm1.net> --------- Co-authored-by: Aman Karmani <aman@tmm1.net>
This commit is contained in:
@@ -518,6 +518,51 @@ trainer.train()
|
||||
|
||||
This script demonstrates how to fine-tune the `google/gemma-2b` model on the IMDB dataset using the GrokAdamW optimizer. The `TrainingArguments` are configured to use GrokAdamW, and the dataset is passed to the `Trainer` for training.
|
||||
|
||||
## Schedule Free Optimizer
|
||||
|
||||
The Schedule Free optimizers have been introduced in [The Road Less Scheduled](https://hf.co/papers/2405.15682).
|
||||
Schedule-Free learning replaces the momentum of the base optimizer with a combination of averaging and interpolation, to completely remove the need to anneal the learning rate with a traditional schedule.
|
||||
Supported optimizers for SFO are `"schedule_free_adamw"` and `"schedule_free_sgd"`. First install schedulefree from pypi `pip install schedulefree`.
|
||||
|
||||
Below is a simple script to demonstrate how to fine-tune [google/gemma-2b](https://huggingface.co/google/gemma-2b) on IMDB dataset in full precision:
|
||||
|
||||
```python
|
||||
import torch
|
||||
import datasets
|
||||
from transformers import TrainingArguments, AutoTokenizer, AutoModelForCausalLM
|
||||
import trl
|
||||
|
||||
train_dataset = datasets.load_dataset('imdb', split='train')
|
||||
|
||||
args = TrainingArguments(
|
||||
output_dir="./test-schedulefree",
|
||||
max_steps=1000,
|
||||
per_device_train_batch_size=4,
|
||||
optim="schedule_free_adamw",
|
||||
gradient_checkpointing=True,
|
||||
logging_strategy="steps",
|
||||
logging_steps=1,
|
||||
learning_rate=2e-6,
|
||||
save_strategy="no",
|
||||
run_name="sfo-imdb",
|
||||
)
|
||||
|
||||
model_id = "google/gemma-2b"
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True).to(0)
|
||||
|
||||
trainer = trl.SFTTrainer(
|
||||
model=model,
|
||||
args=args,
|
||||
train_dataset=train_dataset,
|
||||
dataset_text_field='text',
|
||||
max_seq_length=1024,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
## Accelerate and Trainer
|
||||
|
||||
The [`Trainer`] class is powered by [Accelerate](https://hf.co/docs/accelerate), a library for easily training PyTorch models in distributed environments with support for integrations such as [FullyShardedDataParallel (FSDP)](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/) and [DeepSpeed](https://www.deepspeed.ai/).
|
||||
|
||||
1
setup.py
1
setup.py
@@ -163,6 +163,7 @@ _deps = [
|
||||
"sacremoses",
|
||||
"safetensors>=0.4.1",
|
||||
"sagemaker>=2.31.0",
|
||||
"schedulefree>=1.2.6",
|
||||
"scikit-learn",
|
||||
"scipy<1.13.0", # SciPy >= 1.13.0 is not supported with the current jax pin (`jax>=0.4.1,<=0.4.13`)
|
||||
"sentencepiece>=0.1.91,!=0.1.92",
|
||||
|
||||
@@ -69,6 +69,7 @@ deps = {
|
||||
"sacremoses": "sacremoses",
|
||||
"safetensors": "safetensors>=0.4.1",
|
||||
"sagemaker": "sagemaker>=2.31.0",
|
||||
"schedulefree": "schedulefree>=1.2.6",
|
||||
"scikit-learn": "scikit-learn",
|
||||
"scipy": "scipy<1.13.0",
|
||||
"sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
|
||||
|
||||
@@ -103,6 +103,7 @@ from .utils import (
|
||||
is_rjieba_available,
|
||||
is_sacremoses_available,
|
||||
is_safetensors_available,
|
||||
is_schedulefree_available,
|
||||
is_scipy_available,
|
||||
is_sentencepiece_available,
|
||||
is_seqio_available,
|
||||
@@ -370,6 +371,14 @@ def require_grokadamw(test_case):
|
||||
return unittest.skipUnless(is_grokadamw_available(), "test requires GrokAdamW")(test_case)
|
||||
|
||||
|
||||
def require_schedulefree(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires schedulefree. These tests are skipped when schedulefree isn't installed.
|
||||
https://github.com/facebookresearch/schedule_free
|
||||
"""
|
||||
return unittest.skipUnless(is_schedulefree_available(), "test requires schedulefree")(test_case)
|
||||
|
||||
|
||||
def require_cv2(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires OpenCV.
|
||||
|
||||
@@ -161,6 +161,7 @@ from .utils import (
|
||||
is_safetensors_available,
|
||||
is_sagemaker_dp_enabled,
|
||||
is_sagemaker_mp_enabled,
|
||||
is_schedulefree_available,
|
||||
is_torch_compile_available,
|
||||
is_torch_mlu_available,
|
||||
is_torch_mps_available,
|
||||
@@ -1488,6 +1489,36 @@ class Trainer:
|
||||
|
||||
optimizer_cls = AdamW4bit
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
elif args.optim in [
|
||||
OptimizerNames.SCHEDULE_FREE_ADAMW,
|
||||
OptimizerNames.SCHEDULE_FREE_SGD,
|
||||
]:
|
||||
if not is_schedulefree_available():
|
||||
raise ImportError(
|
||||
"You need to install `schedulefree` in order to use schedulefree optimizers"
|
||||
" install it with `pip install schedulefree`"
|
||||
)
|
||||
if not is_accelerate_available("0.30.0"):
|
||||
raise ImportError("You need to have `accelerate>=0.30.0` to be able to use schedulefree optimizers")
|
||||
from schedulefree import AdamWScheduleFree, SGDScheduleFree
|
||||
|
||||
additional_optim_kwargs = {}
|
||||
if args.optim == OptimizerNames.SCHEDULE_FREE_ADAMW:
|
||||
optimizer_cls = AdamWScheduleFree
|
||||
additional_optim_kwargs = adam_kwargs
|
||||
elif args.optim == OptimizerNames.SCHEDULE_FREE_SGD:
|
||||
optimizer_cls = SGDScheduleFree
|
||||
else:
|
||||
raise ValueError("Invalid schedulefree optimizer")
|
||||
additional_optim_kwargs["weight_decay"] = args.weight_decay
|
||||
additional_optim_kwargs["warmup_steps"] = args.warmup_steps
|
||||
additional_optim_kwargs.update(
|
||||
{
|
||||
"weight_lr_power": float(optim_args.get("weight_lr_power", 2.0)),
|
||||
"r": float(optim_args.get("r", 0.0)),
|
||||
}
|
||||
)
|
||||
optimizer_kwargs.update(additional_optim_kwargs)
|
||||
else:
|
||||
raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
|
||||
return optimizer_cls, optimizer_kwargs
|
||||
@@ -3410,6 +3441,9 @@ class Trainer:
|
||||
`torch.Tensor`: The tensor with training loss on this batch.
|
||||
"""
|
||||
model.train()
|
||||
if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
|
||||
self.optimizer.train()
|
||||
|
||||
inputs = self._prepare_inputs(inputs)
|
||||
if is_sagemaker_mp_enabled():
|
||||
loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
|
||||
@@ -3960,6 +3994,8 @@ class Trainer:
|
||||
logger.info(f" Batch size = {batch_size}")
|
||||
|
||||
model.eval()
|
||||
if hasattr(self.optimizer, "eval") and callable(self.optimizer.eval):
|
||||
self.optimizer.eval()
|
||||
|
||||
self.callback_handler.eval_dataloader = dataloader
|
||||
# Do this before wrapping.
|
||||
@@ -4573,6 +4609,8 @@ class Trainer:
|
||||
inputs_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
|
||||
|
||||
model.eval()
|
||||
if hasattr(self.optimizer, "eval") and callable(self.optimizer.eval):
|
||||
self.optimizer.eval()
|
||||
|
||||
if args.past_index >= 0:
|
||||
self._past = None
|
||||
|
||||
@@ -178,6 +178,8 @@ class OptimizerNames(ExplicitEnum):
|
||||
LOMO = "lomo"
|
||||
ADALOMO = "adalomo"
|
||||
GROKADAMW = "grokadamw"
|
||||
SCHEDULE_FREE_ADAMW = "schedule_free_adamw"
|
||||
SCHEDULE_FREE_SGD = "schedule_free_sgd"
|
||||
|
||||
|
||||
# Sometimes users will pass in a `str` repr of a dict in the CLI
|
||||
|
||||
@@ -175,6 +175,7 @@ from .import_utils import (
|
||||
is_safetensors_available,
|
||||
is_sagemaker_dp_enabled,
|
||||
is_sagemaker_mp_enabled,
|
||||
is_schedulefree_available,
|
||||
is_scipy_available,
|
||||
is_sentencepiece_available,
|
||||
is_seqio_available,
|
||||
|
||||
@@ -103,6 +103,7 @@ _fbgemm_gpu_available = _is_package_available("fbgemm_gpu")
|
||||
_galore_torch_available = _is_package_available("galore_torch")
|
||||
_lomo_available = _is_package_available("lomo_optim")
|
||||
_grokadamw_available = _is_package_available("grokadamw")
|
||||
_schedulefree_available = _is_package_available("schedulefree")
|
||||
# `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
|
||||
_bs4_available = importlib.util.find_spec("bs4") is not None
|
||||
_coloredlogs_available = _is_package_available("coloredlogs")
|
||||
@@ -364,6 +365,10 @@ def is_grokadamw_available():
|
||||
return _grokadamw_available
|
||||
|
||||
|
||||
def is_schedulefree_available():
|
||||
return _schedulefree_available
|
||||
|
||||
|
||||
def is_pyctcdecode_available():
|
||||
return _pyctcdecode_available
|
||||
|
||||
|
||||
@@ -70,6 +70,7 @@ from transformers.testing_utils import (
|
||||
require_peft,
|
||||
require_ray,
|
||||
require_safetensors,
|
||||
require_schedulefree,
|
||||
require_sentencepiece,
|
||||
require_sigopt,
|
||||
require_tensorboard,
|
||||
@@ -1442,6 +1443,27 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
# Check this works
|
||||
_ = trainer.train()
|
||||
|
||||
@require_schedulefree
|
||||
@require_torch_gpu
|
||||
def test_schedulefree_adam(self):
|
||||
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||
tiny_llama = LlamaForCausalLM(config)
|
||||
x = torch.randint(0, 100, (128,))
|
||||
train_dataset = RepeatDataset(x)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Trainer without inf/nan filter
|
||||
args = TrainingArguments(
|
||||
tmpdir,
|
||||
learning_rate=1e-9,
|
||||
logging_steps=5,
|
||||
optim="schedule_free_adamw",
|
||||
)
|
||||
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
|
||||
|
||||
# Check this works
|
||||
_ = trainer.train()
|
||||
|
||||
def test_galore_matched_modules(self):
|
||||
regex_patterns = [r".*.attn.*", r".*.mlp.*"]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user