FEAT / Trainer: LOMO optimizer support (#30178)
* add V1 - adalomo not working yet * add todo docs + refactor from comments * adjust LR * add docs * add more elaborated test * Apply suggestions from code review Co-authored-by: Zach Mueller <muellerzr@gmail.com> * fix * push * add accelerate check * fix DDP case * Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * fix * init kwargs * safely add attribute * revert to enum logic * Update src/transformers/trainer.py --------- Co-authored-by: Zach Mueller <muellerzr@gmail.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -382,6 +382,56 @@ trainer.train()
|
|||||||
|
|
||||||
Note layerwise optimization is a bit experimental and does not support DDP (Distributed Data Parallel), thus you can run the training script only on a single GPU. Please see [this appropriate section](https://github.com/jiaweizzhao/GaLore?tab=readme-ov-file#train-7b-model-with-a-single-gpu-with-24gb-memory) for more details. Other features such as gradient clipping, DeepSpeed, etc might not be supported out of the box. Please [raise an issue on GitHub](https://github.com/huggingface/transformers/issues) if you encounter such issue.
|
Note layerwise optimization is a bit experimental and does not support DDP (Distributed Data Parallel), thus you can run the training script only on a single GPU. Please see [this appropriate section](https://github.com/jiaweizzhao/GaLore?tab=readme-ov-file#train-7b-model-with-a-single-gpu-with-24gb-memory) for more details. Other features such as gradient clipping, DeepSpeed, etc might not be supported out of the box. Please [raise an issue on GitHub](https://github.com/huggingface/transformers/issues) if you encounter such issue.
|
||||||
|
|
||||||
|
## LOMO optimizer
|
||||||
|
|
||||||
|
The LOMO optimizers have been introduced in [Full Parameter Fine-Tuning for Large Language Models with Limited Resources](https://hf.co/papers/2306.09782) and [AdaLomo: Low-memory Optimization with Adaptive Learning Rate](https://hf.co/papers/2310.10195).
|
||||||
|
They both consist of an efficient full-parameter fine-tuning method. These optimizers fuse the gradient computation and the parameter update in one step to reduce memory usage. Supported optimizers for LOMO are `"lomo"` and `"adalomo"`. First either install LOMO from pypi `pip install lomo-optim` or install it from source with `pip install git+https://github.com/OpenLMLab/LOMO.git`.
|
||||||
|
|
||||||
|
<Tip>
|
||||||
|
|
||||||
|
According to the authors, it is recommended to use `AdaLomo` without `grad_norm` to get better performance and higher throughput.
|
||||||
|
|
||||||
|
</Tip>
|
||||||
|
|
||||||
|
Below is a simple script to demonstrate how to fine-tune [google/gemma-2b](https://huggingface.co/google/gemma-2b) on IMDB dataset in full precision:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
import datasets
|
||||||
|
from transformers import TrainingArguments, AutoTokenizer, AutoModelForCausalLM
|
||||||
|
import trl
|
||||||
|
|
||||||
|
train_dataset = datasets.load_dataset('imdb', split='train')
|
||||||
|
|
||||||
|
args = TrainingArguments(
|
||||||
|
output_dir="./test-lomo",
|
||||||
|
max_steps=1000,
|
||||||
|
per_device_train_batch_size=4,
|
||||||
|
optim="adalomo",
|
||||||
|
gradient_checkpointing=True,
|
||||||
|
logging_strategy="steps",
|
||||||
|
logging_steps=1,
|
||||||
|
learning_rate=2e-6,
|
||||||
|
save_strategy="no",
|
||||||
|
run_name="lomo-imdb",
|
||||||
|
)
|
||||||
|
|
||||||
|
model_id = "google/gemma-2b"
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True).to(0)
|
||||||
|
|
||||||
|
trainer = trl.SFTTrainer(
|
||||||
|
model=model,
|
||||||
|
args=args,
|
||||||
|
train_dataset=train_dataset,
|
||||||
|
dataset_text_field='text',
|
||||||
|
max_seq_length=1024,
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer.train()
|
||||||
|
```
|
||||||
|
|
||||||
## Accelerate and Trainer
|
## Accelerate and Trainer
|
||||||
|
|
||||||
The [`Trainer`] class is powered by [Accelerate](https://hf.co/docs/accelerate), a library for easily training PyTorch models in distributed environments with support for integrations such as [FullyShardedDataParallel (FSDP)](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/) and [DeepSpeed](https://www.deepspeed.ai/).
|
The [`Trainer`] class is powered by [Accelerate](https://hf.co/docs/accelerate), a library for easily training PyTorch models in distributed environments with support for integrations such as [FullyShardedDataParallel (FSDP)](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/) and [DeepSpeed](https://www.deepspeed.ai/).
|
||||||
|
|||||||
@@ -82,6 +82,7 @@ from .utils import (
|
|||||||
is_keras_nlp_available,
|
is_keras_nlp_available,
|
||||||
is_levenshtein_available,
|
is_levenshtein_available,
|
||||||
is_librosa_available,
|
is_librosa_available,
|
||||||
|
is_lomo_available,
|
||||||
is_natten_available,
|
is_natten_available,
|
||||||
is_nltk_available,
|
is_nltk_available,
|
||||||
is_onnx_available,
|
is_onnx_available,
|
||||||
@@ -338,6 +339,14 @@ def require_galore_torch(test_case):
|
|||||||
return unittest.skipUnless(is_galore_torch_available(), "test requires GaLore")(test_case)
|
return unittest.skipUnless(is_galore_torch_available(), "test requires GaLore")(test_case)
|
||||||
|
|
||||||
|
|
||||||
|
def require_lomo(test_case):
|
||||||
|
"""
|
||||||
|
Decorator marking a test that requires LOMO. These tests are skipped when LOMO-optim isn't installed.
|
||||||
|
https://github.com/OpenLMLab/LOMO
|
||||||
|
"""
|
||||||
|
return unittest.skipUnless(is_lomo_available(), "test requires LOMO")(test_case)
|
||||||
|
|
||||||
|
|
||||||
def require_cv2(test_case):
|
def require_cv2(test_case):
|
||||||
"""
|
"""
|
||||||
Decorator marking a test that requires OpenCV.
|
Decorator marking a test that requires OpenCV.
|
||||||
|
|||||||
@@ -153,6 +153,7 @@ from .utils import (
|
|||||||
is_galore_torch_available,
|
is_galore_torch_available,
|
||||||
is_in_notebook,
|
is_in_notebook,
|
||||||
is_ipex_available,
|
is_ipex_available,
|
||||||
|
is_lomo_available,
|
||||||
is_peft_available,
|
is_peft_available,
|
||||||
is_safetensors_available,
|
is_safetensors_available,
|
||||||
is_sagemaker_dp_enabled,
|
is_sagemaker_dp_enabled,
|
||||||
@@ -1059,12 +1060,18 @@ class Trainer:
|
|||||||
if "params" in optimizer_kwargs:
|
if "params" in optimizer_kwargs:
|
||||||
optimizer_grouped_parameters = optimizer_kwargs.pop("params")
|
optimizer_grouped_parameters = optimizer_kwargs.pop("params")
|
||||||
|
|
||||||
|
# Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
|
||||||
|
# e.g. for LOMO optimizer.
|
||||||
|
if "model" in optimizer_kwargs:
|
||||||
|
optimizer_grouped_parameters = optimizer_kwargs.pop("model")
|
||||||
|
|
||||||
# For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
|
# For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
|
||||||
# to avoid arguments conflicts.
|
# to avoid arguments conflicts.
|
||||||
if "optimizer_dict" in optimizer_kwargs:
|
if "optimizer_dict" in optimizer_kwargs:
|
||||||
optimizer_grouped_parameters = optimizer_kwargs.pop("optimizer_dict")
|
optimizer_grouped_parameters = optimizer_kwargs.pop("optimizer_dict")
|
||||||
|
|
||||||
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
|
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
|
||||||
|
|
||||||
if optimizer_cls.__name__ == "Adam8bit":
|
if optimizer_cls.__name__ == "Adam8bit":
|
||||||
import bitsandbytes
|
import bitsandbytes
|
||||||
|
|
||||||
@@ -1382,6 +1389,26 @@ class Trainer:
|
|||||||
|
|
||||||
if args.optim == OptimizerNames.GALORE_ADAFACTOR:
|
if args.optim == OptimizerNames.GALORE_ADAFACTOR:
|
||||||
optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
|
optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
|
||||||
|
elif args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
|
||||||
|
if not is_lomo_available():
|
||||||
|
raise ImportError(
|
||||||
|
"You need to install `lomo_optim` in order to use LOMO optimizers"
|
||||||
|
" install it with `pip install lomo-optim`"
|
||||||
|
)
|
||||||
|
if not is_accelerate_available("0.30.0"):
|
||||||
|
raise ImportError("You need to have `accelerate>=0.30.0` to be able to use LOMO optimizers")
|
||||||
|
|
||||||
|
if model is None:
|
||||||
|
raise ValueError("You need to pass a `model` in order to correctly initialize a LOMO optimizer.")
|
||||||
|
|
||||||
|
from lomo_optim import AdaLomo, Lomo
|
||||||
|
|
||||||
|
if "ada" in args.optim:
|
||||||
|
optimizer_cls = AdaLomo
|
||||||
|
else:
|
||||||
|
optimizer_cls = Lomo
|
||||||
|
|
||||||
|
optimizer_kwargs.update({"model": model})
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
|
raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
|
||||||
return optimizer_cls, optimizer_kwargs
|
return optimizer_cls, optimizer_kwargs
|
||||||
@@ -2045,6 +2072,9 @@ class Trainer:
|
|||||||
model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
|
model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
|
||||||
self.model, self.optimizer, self.lr_scheduler
|
self.model, self.optimizer, self.lr_scheduler
|
||||||
)
|
)
|
||||||
|
elif self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
|
||||||
|
# In this case we are in DDP + LOMO, which should be supported
|
||||||
|
self.optimizer = self.accelerator.prepare(self.optimizer)
|
||||||
|
|
||||||
if self.is_fsdp_enabled:
|
if self.is_fsdp_enabled:
|
||||||
self.model = self.model_wrapped = model
|
self.model = self.model_wrapped = model
|
||||||
@@ -2143,7 +2173,6 @@ class Trainer:
|
|||||||
self._globalstep_last_logged = self.state.global_step
|
self._globalstep_last_logged = self.state.global_step
|
||||||
model.zero_grad()
|
model.zero_grad()
|
||||||
grad_norm: Optional[float] = None
|
grad_norm: Optional[float] = None
|
||||||
|
|
||||||
self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
|
self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
|
||||||
|
|
||||||
total_batched_samples = 0
|
total_batched_samples = 0
|
||||||
@@ -2275,8 +2304,8 @@ class Trainer:
|
|||||||
else:
|
else:
|
||||||
grad_norm = _grad_norm
|
grad_norm = _grad_norm
|
||||||
|
|
||||||
# Optimizer step
|
|
||||||
self.optimizer.step()
|
self.optimizer.step()
|
||||||
|
|
||||||
optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
|
optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
|
||||||
if optimizer_was_run:
|
if optimizer_was_run:
|
||||||
# Delay optimizer scheduling until metrics are generated
|
# Delay optimizer scheduling until metrics are generated
|
||||||
@@ -3229,7 +3258,6 @@ class Trainer:
|
|||||||
"""
|
"""
|
||||||
model.train()
|
model.train()
|
||||||
inputs = self._prepare_inputs(inputs)
|
inputs = self._prepare_inputs(inputs)
|
||||||
|
|
||||||
if is_sagemaker_mp_enabled():
|
if is_sagemaker_mp_enabled():
|
||||||
loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
|
loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
|
||||||
return loss_mb.reduce_mean().detach().to(self.args.device)
|
return loss_mb.reduce_mean().detach().to(self.args.device)
|
||||||
@@ -3240,6 +3268,12 @@ class Trainer:
|
|||||||
del inputs
|
del inputs
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
kwargs = {}
|
||||||
|
|
||||||
|
# For LOMO optimizers you need to explicitly use the learnign rate
|
||||||
|
if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
|
||||||
|
kwargs["learning_rate"] = self._get_learning_rate()
|
||||||
|
|
||||||
if self.args.n_gpu > 1:
|
if self.args.n_gpu > 1:
|
||||||
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
||||||
|
|
||||||
@@ -3247,7 +3281,7 @@ class Trainer:
|
|||||||
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
|
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
|
||||||
scaled_loss.backward()
|
scaled_loss.backward()
|
||||||
else:
|
else:
|
||||||
self.accelerator.backward(loss)
|
self.accelerator.backward(loss, **kwargs)
|
||||||
|
|
||||||
return loss.detach() / self.args.gradient_accumulation_steps
|
return loss.detach() / self.args.gradient_accumulation_steps
|
||||||
|
|
||||||
|
|||||||
@@ -171,6 +171,8 @@ class OptimizerNames(ExplicitEnum):
|
|||||||
GALORE_ADAMW_LAYERWISE = "galore_adamw_layerwise"
|
GALORE_ADAMW_LAYERWISE = "galore_adamw_layerwise"
|
||||||
GALORE_ADAMW_8BIT_LAYERWISE = "galore_adamw_8bit_layerwise"
|
GALORE_ADAMW_8BIT_LAYERWISE = "galore_adamw_8bit_layerwise"
|
||||||
GALORE_ADAFACTOR_LAYERWISE = "galore_adafactor_layerwise"
|
GALORE_ADAFACTOR_LAYERWISE = "galore_adafactor_layerwise"
|
||||||
|
LOMO = "lomo"
|
||||||
|
ADALOMO = "adalomo"
|
||||||
|
|
||||||
|
|
||||||
# Sometimes users will pass in a `str` repr of a dict in the CLI
|
# Sometimes users will pass in a `str` repr of a dict in the CLI
|
||||||
|
|||||||
@@ -141,6 +141,7 @@ from .import_utils import (
|
|||||||
is_keras_nlp_available,
|
is_keras_nlp_available,
|
||||||
is_levenshtein_available,
|
is_levenshtein_available,
|
||||||
is_librosa_available,
|
is_librosa_available,
|
||||||
|
is_lomo_available,
|
||||||
is_mlx_available,
|
is_mlx_available,
|
||||||
is_natten_available,
|
is_natten_available,
|
||||||
is_ninja_available,
|
is_ninja_available,
|
||||||
|
|||||||
@@ -99,6 +99,7 @@ _av_available = importlib.util.find_spec("av") is not None
|
|||||||
_bitsandbytes_available = _is_package_available("bitsandbytes")
|
_bitsandbytes_available = _is_package_available("bitsandbytes")
|
||||||
_eetq_available = _is_package_available("eetq")
|
_eetq_available = _is_package_available("eetq")
|
||||||
_galore_torch_available = _is_package_available("galore_torch")
|
_galore_torch_available = _is_package_available("galore_torch")
|
||||||
|
_lomo_available = _is_package_available("lomo_optim")
|
||||||
# `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
|
# `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
|
||||||
_bs4_available = importlib.util.find_spec("bs4") is not None
|
_bs4_available = importlib.util.find_spec("bs4") is not None
|
||||||
_coloredlogs_available = _is_package_available("coloredlogs")
|
_coloredlogs_available = _is_package_available("coloredlogs")
|
||||||
@@ -328,6 +329,10 @@ def is_galore_torch_available():
|
|||||||
return _galore_torch_available
|
return _galore_torch_available
|
||||||
|
|
||||||
|
|
||||||
|
def is_lomo_available():
|
||||||
|
return _lomo_available
|
||||||
|
|
||||||
|
|
||||||
def is_pyctcdecode_available():
|
def is_pyctcdecode_available():
|
||||||
return _pyctcdecode_available
|
return _pyctcdecode_available
|
||||||
|
|
||||||
|
|||||||
@@ -63,6 +63,7 @@ from transformers.testing_utils import (
|
|||||||
require_deepspeed,
|
require_deepspeed,
|
||||||
require_galore_torch,
|
require_galore_torch,
|
||||||
require_intel_extension_for_pytorch,
|
require_intel_extension_for_pytorch,
|
||||||
|
require_lomo,
|
||||||
require_optuna,
|
require_optuna,
|
||||||
require_peft,
|
require_peft,
|
||||||
require_ray,
|
require_ray,
|
||||||
@@ -1229,6 +1230,49 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
trainer.train()
|
trainer.train()
|
||||||
trainer.evaluate()
|
trainer.evaluate()
|
||||||
|
|
||||||
|
@require_lomo
|
||||||
|
@require_torch_gpu
|
||||||
|
def test_lomo(self):
|
||||||
|
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||||
|
tiny_llama = LlamaForCausalLM(config)
|
||||||
|
|
||||||
|
previous_params = {n: p.clone() for n, p in tiny_llama.named_parameters()}
|
||||||
|
|
||||||
|
x = torch.randint(0, 100, (128,))
|
||||||
|
train_dataset = RepeatDataset(x)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
# Trainer without inf/nan filter
|
||||||
|
args = TrainingArguments(tmpdir, learning_rate=1e-2, logging_steps=5, optim="lomo", max_steps=20)
|
||||||
|
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
|
||||||
|
|
||||||
|
# Check this works
|
||||||
|
_ = trainer.train()
|
||||||
|
|
||||||
|
for name, param in tiny_llama.named_parameters():
|
||||||
|
self.assertFalse(torch.allclose(param, previous_params[name].to(param.device), rtol=1e-12, atol=1e-12))
|
||||||
|
|
||||||
|
@require_lomo
|
||||||
|
@require_torch_gpu
|
||||||
|
def test_adalomo(self):
|
||||||
|
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||||
|
tiny_llama = LlamaForCausalLM(config)
|
||||||
|
x = torch.randint(0, 100, (128,))
|
||||||
|
train_dataset = RepeatDataset(x)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
# Trainer without inf/nan filter
|
||||||
|
args = TrainingArguments(
|
||||||
|
tmpdir,
|
||||||
|
learning_rate=1e-9,
|
||||||
|
logging_steps=5,
|
||||||
|
optim="adalomo",
|
||||||
|
)
|
||||||
|
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
|
||||||
|
|
||||||
|
# Check this works
|
||||||
|
_ = trainer.train()
|
||||||
|
|
||||||
def test_galore_matched_modules(self):
|
def test_galore_matched_modules(self):
|
||||||
regex_patterns = [r".*.attn.*", r".*.mlp.*"]
|
regex_patterns = [r".*.attn.*", r".*.mlp.*"]
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user