schedulefree optimizers (#30079)

* schedulefree optimizers

* fix train instead of eval for optimizer

* fixes and update docs

* chore: lint

* add tests and drop overly-verbose _32bit suffix

* chore: lint

* fix for docs

* fix code review issues

* use duck-typing to avoid per-optimizer patches

* fixup style

* fixup style

* warn if incorrect accelerate version with schedule free

Co-authored-by: Aman Gupta Karmani <aman@tmm1.net>

---------

Co-authored-by: Aman Karmani <aman@tmm1.net>
This commit is contained in:
Wing Lian
2024-09-09 03:51:39 -04:00
committed by GitHub
parent 60226fdc1d
commit 62aecd85ff
9 changed files with 124 additions and 0 deletions

View File

@@ -518,6 +518,51 @@ trainer.train()
This script demonstrates how to fine-tune the `google/gemma-2b` model on the IMDB dataset using the GrokAdamW optimizer. The `TrainingArguments` are configured to use GrokAdamW, and the dataset is passed to the `Trainer` for training. This script demonstrates how to fine-tune the `google/gemma-2b` model on the IMDB dataset using the GrokAdamW optimizer. The `TrainingArguments` are configured to use GrokAdamW, and the dataset is passed to the `Trainer` for training.
## Schedule Free Optimizer
The Schedule Free optimizers have been introduced in [The Road Less Scheduled](https://hf.co/papers/2405.15682).
Schedule-Free learning replaces the momentum of the base optimizer with a combination of averaging and interpolation, to completely remove the need to anneal the learning rate with a traditional schedule.
Supported optimizers for SFO are `"schedule_free_adamw"` and `"schedule_free_sgd"`. First install schedulefree from pypi `pip install schedulefree`.
Below is a simple script to demonstrate how to fine-tune [google/gemma-2b](https://huggingface.co/google/gemma-2b) on IMDB dataset in full precision:
```python
import torch
import datasets
from transformers import TrainingArguments, AutoTokenizer, AutoModelForCausalLM
import trl
train_dataset = datasets.load_dataset('imdb', split='train')
args = TrainingArguments(
output_dir="./test-schedulefree",
max_steps=1000,
per_device_train_batch_size=4,
optim="schedule_free_adamw",
gradient_checkpointing=True,
logging_strategy="steps",
logging_steps=1,
learning_rate=2e-6,
save_strategy="no",
run_name="sfo-imdb",
)
model_id = "google/gemma-2b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True).to(0)
trainer = trl.SFTTrainer(
model=model,
args=args,
train_dataset=train_dataset,
dataset_text_field='text',
max_seq_length=1024,
)
trainer.train()
```
## Accelerate and Trainer ## Accelerate and Trainer
The [`Trainer`] class is powered by [Accelerate](https://hf.co/docs/accelerate), a library for easily training PyTorch models in distributed environments with support for integrations such as [FullyShardedDataParallel (FSDP)](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/) and [DeepSpeed](https://www.deepspeed.ai/). The [`Trainer`] class is powered by [Accelerate](https://hf.co/docs/accelerate), a library for easily training PyTorch models in distributed environments with support for integrations such as [FullyShardedDataParallel (FSDP)](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/) and [DeepSpeed](https://www.deepspeed.ai/).

View File

@@ -163,6 +163,7 @@ _deps = [
"sacremoses", "sacremoses",
"safetensors>=0.4.1", "safetensors>=0.4.1",
"sagemaker>=2.31.0", "sagemaker>=2.31.0",
"schedulefree>=1.2.6",
"scikit-learn", "scikit-learn",
"scipy<1.13.0", # SciPy >= 1.13.0 is not supported with the current jax pin (`jax>=0.4.1,<=0.4.13`) "scipy<1.13.0", # SciPy >= 1.13.0 is not supported with the current jax pin (`jax>=0.4.1,<=0.4.13`)
"sentencepiece>=0.1.91,!=0.1.92", "sentencepiece>=0.1.91,!=0.1.92",

View File

@@ -69,6 +69,7 @@ deps = {
"sacremoses": "sacremoses", "sacremoses": "sacremoses",
"safetensors": "safetensors>=0.4.1", "safetensors": "safetensors>=0.4.1",
"sagemaker": "sagemaker>=2.31.0", "sagemaker": "sagemaker>=2.31.0",
"schedulefree": "schedulefree>=1.2.6",
"scikit-learn": "scikit-learn", "scikit-learn": "scikit-learn",
"scipy": "scipy<1.13.0", "scipy": "scipy<1.13.0",
"sentencepiece": "sentencepiece>=0.1.91,!=0.1.92", "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",

View File

@@ -103,6 +103,7 @@ from .utils import (
is_rjieba_available, is_rjieba_available,
is_sacremoses_available, is_sacremoses_available,
is_safetensors_available, is_safetensors_available,
is_schedulefree_available,
is_scipy_available, is_scipy_available,
is_sentencepiece_available, is_sentencepiece_available,
is_seqio_available, is_seqio_available,
@@ -370,6 +371,14 @@ def require_grokadamw(test_case):
return unittest.skipUnless(is_grokadamw_available(), "test requires GrokAdamW")(test_case) return unittest.skipUnless(is_grokadamw_available(), "test requires GrokAdamW")(test_case)
def require_schedulefree(test_case):
"""
Decorator marking a test that requires schedulefree. These tests are skipped when schedulefree isn't installed.
https://github.com/facebookresearch/schedule_free
"""
return unittest.skipUnless(is_schedulefree_available(), "test requires schedulefree")(test_case)
def require_cv2(test_case): def require_cv2(test_case):
""" """
Decorator marking a test that requires OpenCV. Decorator marking a test that requires OpenCV.

View File

@@ -161,6 +161,7 @@ from .utils import (
is_safetensors_available, is_safetensors_available,
is_sagemaker_dp_enabled, is_sagemaker_dp_enabled,
is_sagemaker_mp_enabled, is_sagemaker_mp_enabled,
is_schedulefree_available,
is_torch_compile_available, is_torch_compile_available,
is_torch_mlu_available, is_torch_mlu_available,
is_torch_mps_available, is_torch_mps_available,
@@ -1488,6 +1489,36 @@ class Trainer:
optimizer_cls = AdamW4bit optimizer_cls = AdamW4bit
optimizer_kwargs.update(adam_kwargs) optimizer_kwargs.update(adam_kwargs)
elif args.optim in [
OptimizerNames.SCHEDULE_FREE_ADAMW,
OptimizerNames.SCHEDULE_FREE_SGD,
]:
if not is_schedulefree_available():
raise ImportError(
"You need to install `schedulefree` in order to use schedulefree optimizers"
" install it with `pip install schedulefree`"
)
if not is_accelerate_available("0.30.0"):
raise ImportError("You need to have `accelerate>=0.30.0` to be able to use schedulefree optimizers")
from schedulefree import AdamWScheduleFree, SGDScheduleFree
additional_optim_kwargs = {}
if args.optim == OptimizerNames.SCHEDULE_FREE_ADAMW:
optimizer_cls = AdamWScheduleFree
additional_optim_kwargs = adam_kwargs
elif args.optim == OptimizerNames.SCHEDULE_FREE_SGD:
optimizer_cls = SGDScheduleFree
else:
raise ValueError("Invalid schedulefree optimizer")
additional_optim_kwargs["weight_decay"] = args.weight_decay
additional_optim_kwargs["warmup_steps"] = args.warmup_steps
additional_optim_kwargs.update(
{
"weight_lr_power": float(optim_args.get("weight_lr_power", 2.0)),
"r": float(optim_args.get("r", 0.0)),
}
)
optimizer_kwargs.update(additional_optim_kwargs)
else: else:
raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}") raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
return optimizer_cls, optimizer_kwargs return optimizer_cls, optimizer_kwargs
@@ -3410,6 +3441,9 @@ class Trainer:
`torch.Tensor`: The tensor with training loss on this batch. `torch.Tensor`: The tensor with training loss on this batch.
""" """
model.train() model.train()
if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
self.optimizer.train()
inputs = self._prepare_inputs(inputs) inputs = self._prepare_inputs(inputs)
if is_sagemaker_mp_enabled(): if is_sagemaker_mp_enabled():
loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps) loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
@@ -3960,6 +3994,8 @@ class Trainer:
logger.info(f" Batch size = {batch_size}") logger.info(f" Batch size = {batch_size}")
model.eval() model.eval()
if hasattr(self.optimizer, "eval") and callable(self.optimizer.eval):
self.optimizer.eval()
self.callback_handler.eval_dataloader = dataloader self.callback_handler.eval_dataloader = dataloader
# Do this before wrapping. # Do this before wrapping.
@@ -4573,6 +4609,8 @@ class Trainer:
inputs_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of) inputs_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
model.eval() model.eval()
if hasattr(self.optimizer, "eval") and callable(self.optimizer.eval):
self.optimizer.eval()
if args.past_index >= 0: if args.past_index >= 0:
self._past = None self._past = None

View File

@@ -178,6 +178,8 @@ class OptimizerNames(ExplicitEnum):
LOMO = "lomo" LOMO = "lomo"
ADALOMO = "adalomo" ADALOMO = "adalomo"
GROKADAMW = "grokadamw" GROKADAMW = "grokadamw"
SCHEDULE_FREE_ADAMW = "schedule_free_adamw"
SCHEDULE_FREE_SGD = "schedule_free_sgd"
# Sometimes users will pass in a `str` repr of a dict in the CLI # Sometimes users will pass in a `str` repr of a dict in the CLI

View File

@@ -175,6 +175,7 @@ from .import_utils import (
is_safetensors_available, is_safetensors_available,
is_sagemaker_dp_enabled, is_sagemaker_dp_enabled,
is_sagemaker_mp_enabled, is_sagemaker_mp_enabled,
is_schedulefree_available,
is_scipy_available, is_scipy_available,
is_sentencepiece_available, is_sentencepiece_available,
is_seqio_available, is_seqio_available,

View File

@@ -103,6 +103,7 @@ _fbgemm_gpu_available = _is_package_available("fbgemm_gpu")
_galore_torch_available = _is_package_available("galore_torch") _galore_torch_available = _is_package_available("galore_torch")
_lomo_available = _is_package_available("lomo_optim") _lomo_available = _is_package_available("lomo_optim")
_grokadamw_available = _is_package_available("grokadamw") _grokadamw_available = _is_package_available("grokadamw")
_schedulefree_available = _is_package_available("schedulefree")
# `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed. # `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
_bs4_available = importlib.util.find_spec("bs4") is not None _bs4_available = importlib.util.find_spec("bs4") is not None
_coloredlogs_available = _is_package_available("coloredlogs") _coloredlogs_available = _is_package_available("coloredlogs")
@@ -364,6 +365,10 @@ def is_grokadamw_available():
return _grokadamw_available return _grokadamw_available
def is_schedulefree_available():
return _schedulefree_available
def is_pyctcdecode_available(): def is_pyctcdecode_available():
return _pyctcdecode_available return _pyctcdecode_available

View File

@@ -70,6 +70,7 @@ from transformers.testing_utils import (
require_peft, require_peft,
require_ray, require_ray,
require_safetensors, require_safetensors,
require_schedulefree,
require_sentencepiece, require_sentencepiece,
require_sigopt, require_sigopt,
require_tensorboard, require_tensorboard,
@@ -1442,6 +1443,27 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
# Check this works # Check this works
_ = trainer.train() _ = trainer.train()
@require_schedulefree
@require_torch_gpu
def test_schedulefree_adam(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)
with tempfile.TemporaryDirectory() as tmpdir:
# Trainer without inf/nan filter
args = TrainingArguments(
tmpdir,
learning_rate=1e-9,
logging_steps=5,
optim="schedule_free_adamw",
)
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
# Check this works
_ = trainer.train()
def test_galore_matched_modules(self): def test_galore_matched_modules(self):
regex_patterns = [r".*.attn.*", r".*.mlp.*"] regex_patterns = [r".*.attn.*", r".*.mlp.*"]