FEAT / Optim: Add GaLore optimizer (#29588)
* add galore v1 * add import * add tests and doc * fix doctest * forward contrib credits from discussions * forward contrib credits from discussions * Apply suggestions from code review Co-authored-by: Zach Mueller <muellerzr@gmail.com> * fix failing tests' * switch to `optim_target_modules` and clarify docs * more clarification * enhance lookup logic * update a test to add peak memory * add regex, all-linear and single string support * add layer-wise optimization through DummyOptimizers and LRSchedulers * forward contrib credits from discussions and original idea * add a section about DDP not supported in layerwise * Update src/transformers/trainer.py Co-authored-by: Zach Mueller <muellerzr@gmail.com> * fix self * check only if layer_wise * Update src/transformers/training_args.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * oops * make use of intervals * clarify comment * add matching tests * GaLoRe -> GaLore * move to `get_scheduler` * add note on docs * add a warning * adapt a bit the docs * update docstring * support original API * Update docs/source/en/trainer.md * slightly refactor * Update docs/source/en/trainer.md Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> * Update src/transformers/training_args.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * fix args parsing and add tests * remove warning for regex * fix type hint * add note about extra args * make `is_regex` return optional --------- Co-authored-by: Maxime <maximegmd @users.noreply.github.com> Co-authored-by: Wing Lian <winglian @users.noreply.github.com> Co-authored-by: Zach Mueller <muellerzr@gmail.com> Co-authored-by: hiyouga <hiyouga@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
This commit is contained in:
@@ -252,6 +252,136 @@ trainer = Trainer(..., args=training_args)
|
||||
|
||||
NEFTune is disabled after training to restore the original embedding layer to avoid any unexpected behavior.
|
||||
|
||||
## GaLore
|
||||
|
||||
Gradient Low-Rank Projection (GaLore) is a memory-efficient low-rank training strategy that allows full-parameter learning but is more memory-efficient than common low-rank adaptation methods, such as LoRA.
|
||||
|
||||
First make sure to install GaLore official repository:
|
||||
|
||||
```bash
|
||||
pip install galore-torch
|
||||
```
|
||||
|
||||
Then simply add one of `["galore_adamw", "galore_adafactor", "galore_adamw_8bit"]` in `optim` together with `optim_target_modules`, which can be a list of strings, regex or full path corresponding to the target module names you want to adapt. Below is an end-to-end example script (make sure to `pip install trl datasets`):
|
||||
|
||||
```python
|
||||
import torch
|
||||
import datasets
|
||||
import trl
|
||||
|
||||
from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
train_dataset = datasets.load_dataset('imdb', split='train')
|
||||
|
||||
args = TrainingArguments(
|
||||
output_dir="./test-galore",
|
||||
max_steps=100,
|
||||
per_device_train_batch_size=2,
|
||||
optim="galore_adamw",
|
||||
optim_target_modules=["attn", "mlp"]
|
||||
)
|
||||
|
||||
model_id = "google/gemma-2b"
|
||||
|
||||
config = AutoConfig.from_pretrained(model_id)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
model = AutoModelForCausalLM.from_config(config).to(0)
|
||||
|
||||
trainer = trl.SFTTrainer(
|
||||
model=model,
|
||||
args=args,
|
||||
train_dataset=train_dataset,
|
||||
dataset_text_field='text',
|
||||
max_seq_length=512,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
To pass extra arguments supports by GaLore, you should pass correctly `optim_args`, for example:
|
||||
|
||||
```python
|
||||
import torch
|
||||
import datasets
|
||||
import trl
|
||||
|
||||
from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
train_dataset = datasets.load_dataset('imdb', split='train')
|
||||
|
||||
args = TrainingArguments(
|
||||
output_dir="./test-galore",
|
||||
max_steps=100,
|
||||
per_device_train_batch_size=2,
|
||||
optim="galore_adamw",
|
||||
optim_target_modules=["attn", "mlp"],
|
||||
optim_args="rank=64, update_proj_gap=100, scale=0.10",
|
||||
)
|
||||
|
||||
model_id = "google/gemma-2b"
|
||||
|
||||
config = AutoConfig.from_pretrained(model_id)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
model = AutoModelForCausalLM.from_config(config).to(0)
|
||||
|
||||
trainer = trl.SFTTrainer(
|
||||
model=model,
|
||||
args=args,
|
||||
train_dataset=train_dataset,
|
||||
dataset_text_field='text',
|
||||
max_seq_length=512,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
You can read more about the method in the [original repository](https://github.com/jiaweizzhao/GaLore) or the [paper](https://arxiv.org/abs/2403.03507).
|
||||
|
||||
Currently you can only train Linear layers that are considered as GaLore layers and will use low-rank decomposition to be trained while remaining layers will be optimized in the conventional manner.
|
||||
|
||||
Note it will take a bit of time before starting the training (~3 minutes for a 2B model on a NVIDIA A100), but training should go smoothly afterwards.
|
||||
|
||||
You can also perform layer-wise optimization by post-pending the optimizer name with `layerwise` like below:
|
||||
|
||||
```python
|
||||
import torch
|
||||
import datasets
|
||||
import trl
|
||||
|
||||
from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
train_dataset = datasets.load_dataset('imdb', split='train')
|
||||
|
||||
args = TrainingArguments(
|
||||
output_dir="./test-galore",
|
||||
max_steps=100,
|
||||
per_device_train_batch_size=2,
|
||||
optim="galore_adamw_layerwise",
|
||||
optim_target_modules=["attn", "mlp"]
|
||||
)
|
||||
|
||||
model_id = "google/gemma-2b"
|
||||
|
||||
config = AutoConfig.from_pretrained(model_id)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
model = AutoModelForCausalLM.from_config(config).to(0)
|
||||
|
||||
trainer = trl.SFTTrainer(
|
||||
model=model,
|
||||
args=args,
|
||||
train_dataset=train_dataset,
|
||||
dataset_text_field='text',
|
||||
max_seq_length=512,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
Note layerwise optimization is a bit experimental and does not support DDP (Distributed Data Parallel), thus you can run the training script only on a single GPU. Please see [this appropriate section](https://github.com/jiaweizzhao/GaLore?tab=readme-ov-file#train-7b-model-with-a-single-gpu-with-24gb-memory) for more details. Other features such as gradient clipping, DeepSpeed, etc might not be supported out of the box. Please [raise an issue on GitHub](https://github.com/huggingface/transformers/issues) if you encounter such issue.
|
||||
|
||||
## Accelerate and Trainer
|
||||
|
||||
The [`Trainer`] class is powered by [Accelerate](https://hf.co/docs/accelerate), a library for easily training PyTorch models in distributed environments with support for integrations such as [FullyShardedDataParallel (FSDP)](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/) and [DeepSpeed](https://www.deepspeed.ai/).
|
||||
|
||||
@@ -24,6 +24,7 @@ from torch import nn
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
|
||||
|
||||
from .trainer_pt_utils import LayerWiseDummyOptimizer, LayerWiseDummyScheduler
|
||||
from .trainer_utils import SchedulerType
|
||||
from .utils import logging
|
||||
from .utils.versions import require_version
|
||||
@@ -362,6 +363,32 @@ def get_scheduler(
|
||||
"""
|
||||
name = SchedulerType(name)
|
||||
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
|
||||
|
||||
# If a `LayerWiseDummyOptimizer` is passed we extract the optimizer dict and
|
||||
# recursively call `get_scheduler` to get the proper schedulers on each parameter
|
||||
if optimizer is not None and isinstance(optimizer, LayerWiseDummyOptimizer):
|
||||
optimizer_dict = optimizer.optimizer_dict
|
||||
scheduler_dict = {}
|
||||
|
||||
for param in optimizer_dict.keys():
|
||||
scheduler_dict[param] = get_scheduler(
|
||||
name,
|
||||
optimizer=optimizer_dict[param],
|
||||
num_warmup_steps=num_warmup_steps,
|
||||
num_training_steps=num_training_steps,
|
||||
)
|
||||
|
||||
def scheduler_hook(param):
|
||||
# Since the optimizer hook has been already attached we only need to
|
||||
# attach the scheduler hook
|
||||
if param.grad is not None:
|
||||
scheduler_dict[param].step()
|
||||
|
||||
for param in optimizer_dict.keys():
|
||||
param.register_post_accumulate_grad_hook(scheduler_hook)
|
||||
|
||||
return LayerWiseDummyScheduler()
|
||||
|
||||
if name == SchedulerType.CONSTANT:
|
||||
return schedule_func(optimizer)
|
||||
|
||||
|
||||
@@ -70,6 +70,7 @@ from .utils import (
|
||||
is_fsdp_available,
|
||||
is_ftfy_available,
|
||||
is_g2p_en_available,
|
||||
is_galore_torch_available,
|
||||
is_ipex_available,
|
||||
is_jieba_available,
|
||||
is_jinja_available,
|
||||
@@ -325,6 +326,14 @@ def require_bs4(test_case):
|
||||
return unittest.skipUnless(is_bs4_available(), "test requires BeautifulSoup4")(test_case)
|
||||
|
||||
|
||||
def require_galore_torch(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires GaLore. These tests are skipped when GaLore isn't installed.
|
||||
https://github.com/jiaweizzhao/GaLore
|
||||
"""
|
||||
return unittest.skipUnless(is_galore_torch_available(), "test requires GaLore")(test_case)
|
||||
|
||||
|
||||
def require_cv2(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires OpenCV.
|
||||
|
||||
@@ -83,6 +83,7 @@ from .trainer_pt_utils import (
|
||||
DistributedTensorGatherer,
|
||||
IterableDatasetShard,
|
||||
LabelSmoother,
|
||||
LayerWiseDummyOptimizer,
|
||||
LengthGroupedSampler,
|
||||
SequentialDistributedSampler,
|
||||
distributed_broadcast_scalars,
|
||||
@@ -111,6 +112,7 @@ from .trainer_utils import (
|
||||
RemoveColumnsCollator,
|
||||
TrainerMemoryTracker,
|
||||
TrainOutput,
|
||||
check_target_module_exists,
|
||||
default_compute_objective,
|
||||
denumpify_detensorize,
|
||||
enable_full_determinism,
|
||||
@@ -141,6 +143,7 @@ from .utils import (
|
||||
is_apex_available,
|
||||
is_bitsandbytes_available,
|
||||
is_datasets_available,
|
||||
is_galore_torch_available,
|
||||
is_in_notebook,
|
||||
is_ipex_available,
|
||||
is_peft_available,
|
||||
@@ -1010,7 +1013,17 @@ class Trainer:
|
||||
},
|
||||
]
|
||||
|
||||
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
|
||||
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args, opt_model)
|
||||
|
||||
# Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
|
||||
# e.g. for GaLore optimizer.
|
||||
if "params" in optimizer_kwargs:
|
||||
optimizer_grouped_parameters = optimizer_kwargs.pop("params")
|
||||
|
||||
# For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
|
||||
# to avoid arguments conflicts.
|
||||
if "optimizer_dict" in optimizer_kwargs:
|
||||
optimizer_grouped_parameters = optimizer_kwargs.pop("optimizer_dict")
|
||||
|
||||
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
|
||||
if optimizer_cls.__name__ == "Adam8bit":
|
||||
@@ -1033,7 +1046,9 @@ class Trainer:
|
||||
return self.optimizer
|
||||
|
||||
@staticmethod
|
||||
def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:
|
||||
def get_optimizer_cls_and_kwargs(
|
||||
args: TrainingArguments, model: Optional[PreTrainedModel] = None
|
||||
) -> Tuple[Any, Any]:
|
||||
"""
|
||||
Returns the optimizer class and optimizer parameters based on the training arguments.
|
||||
|
||||
@@ -1171,6 +1186,132 @@ class Trainer:
|
||||
optimizer_cls = torch.optim.Adagrad
|
||||
elif args.optim == OptimizerNames.RMSPROP:
|
||||
optimizer_cls = torch.optim.RMSprop
|
||||
elif args.optim in [
|
||||
OptimizerNames.GALORE_ADAMW,
|
||||
OptimizerNames.GALORE_ADAMW_8BIT,
|
||||
OptimizerNames.GALORE_ADAFACTOR,
|
||||
OptimizerNames.GALORE_ADAMW_LAYERWISE,
|
||||
OptimizerNames.GALORE_ADAMW_8BIT_LAYERWISE,
|
||||
OptimizerNames.GALORE_ADAFACTOR_LAYERWISE,
|
||||
]:
|
||||
if not is_galore_torch_available():
|
||||
raise ImportError(
|
||||
"You need to install `galore_torch` in order to use GaLore optimizers"
|
||||
" install it with `pip install git+https://github.com/jiaweizzhao/GaLore`"
|
||||
)
|
||||
from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit
|
||||
|
||||
is_layerwise = args.optim.lower().endswith("layerwise")
|
||||
if is_layerwise and args.parallel_mode == ParallelMode.DISTRIBUTED:
|
||||
raise NotImplementedError("Layer-wise GaLore does not support DDP at this time")
|
||||
|
||||
optimizer_mapping = {
|
||||
OptimizerNames.GALORE_ADAMW: GaLoreAdamW,
|
||||
OptimizerNames.GALORE_ADAMW_8BIT: GaLoreAdamW8bit,
|
||||
OptimizerNames.GALORE_ADAFACTOR: GaLoreAdafactor,
|
||||
OptimizerNames.GALORE_ADAMW_LAYERWISE: GaLoreAdamW,
|
||||
OptimizerNames.GALORE_ADAMW_8BIT_LAYERWISE: GaLoreAdamW8bit,
|
||||
OptimizerNames.GALORE_ADAFACTOR_LAYERWISE: GaLoreAdafactor,
|
||||
}
|
||||
|
||||
optimizer_cls = optimizer_mapping[args.optim]
|
||||
|
||||
if args.optim_target_modules is None:
|
||||
raise ValueError(
|
||||
"You need to define a `optim_target_modules` in order to properly use GaLore optimizers"
|
||||
)
|
||||
|
||||
if not isinstance(args.optim_target_modules, (list, str)):
|
||||
raise ValueError(
|
||||
f"`optim_target_modules` has to be a list of strings, a string corresponding to a regex, or a specific module or 'all-linear', you passed {args.optim_target_modules}"
|
||||
)
|
||||
|
||||
if model is None:
|
||||
raise ValueError("You need to pass a model in order to correctly initialize a GaLore optimizer.")
|
||||
|
||||
logger.warning(
|
||||
"Activated GaLoRE fine-tuning, depending on your model size and hardware, the training might take a while before starting. Please be patient !"
|
||||
)
|
||||
|
||||
all_linear = (
|
||||
isinstance(args.optim_target_modules, str)
|
||||
and args.optim_target_modules.replace("_", "-") == "all-linear"
|
||||
)
|
||||
|
||||
galore_params = []
|
||||
galore_params_names = []
|
||||
for module_name, module in model.named_modules():
|
||||
target_module_exists, is_regex = check_target_module_exists(
|
||||
args.optim_target_modules, module_name, return_is_regex=True
|
||||
)
|
||||
|
||||
if not isinstance(module, nn.Linear):
|
||||
# Warn in case we match but it's not a linear layer
|
||||
if target_module_exists and not is_regex:
|
||||
logger.warning(
|
||||
f"{module_name} has been matched but ignored as GaLore only supports linear layers. Please double check your `optim_target_modules`!"
|
||||
)
|
||||
|
||||
continue
|
||||
|
||||
if not target_module_exists and not all_linear:
|
||||
continue
|
||||
|
||||
galore_params.append(module.weight)
|
||||
galore_params_names.append(module_name + ".weight")
|
||||
|
||||
if len(galore_params) == 0:
|
||||
raise ValueError(
|
||||
f"None of the target modules were found! ({args.optim_target_modules}). Please make sure to pass a valid `target_modules`."
|
||||
)
|
||||
|
||||
non_galore_params = [p for n, p in model.named_parameters() if n not in galore_params_names]
|
||||
|
||||
galore_optim_kwargs = {
|
||||
"rank": int(optim_args.pop("rank", 128)),
|
||||
"update_proj_gap": int(optim_args.pop("update_proj_gap", 200)),
|
||||
"scale": float(optim_args.pop("scale", 0.25)),
|
||||
"proj_type": optim_args.pop("proj_type", "std"),
|
||||
}
|
||||
|
||||
# The default args are from the official repository: https://github.com/jiaweizzhao/GaLore
|
||||
param_groups = [
|
||||
{"params": non_galore_params},
|
||||
{"params": galore_params, **galore_optim_kwargs},
|
||||
]
|
||||
|
||||
if is_layerwise:
|
||||
# For layer-wise optimizers, the optimization step is done through post accumulation
|
||||
# gradient hooks. The trick is to first attach these hooks to the model parameters then
|
||||
# create a dummy optimizer that will perform no-ops in the Trainer.
|
||||
# See the original implementation or the nice implementation from @hiyouga
|
||||
# here: https://github.com/hiyouga/LLaMA-Factory/commit/8664262cde3919e10eaecbd66e8c5d356856362e#diff-ebe08ab14496dfb9e06075f0fdd36799ef6d1535cc4dd4715b74c4e3e06fe3ba
|
||||
if args.gradient_accumulation_steps != 1:
|
||||
raise ValueError("Layerwise GaLoRE optimizer do not support gradient accumulation !")
|
||||
|
||||
optimizer_dict = {}
|
||||
for param in non_galore_params:
|
||||
param_groups = [{"params": [param]}]
|
||||
optimizer_dict[param] = optimizer_cls(param_groups, **optimizer_kwargs)
|
||||
for param in galore_params:
|
||||
param_groups = [{"params": [param], **galore_optim_kwargs}]
|
||||
optimizer_dict[param] = optimizer_cls(param_groups, **optimizer_kwargs)
|
||||
|
||||
def optimizer_hook(param):
|
||||
if param.grad is not None:
|
||||
optimizer_dict[param].step()
|
||||
optimizer_dict[param].zero_grad()
|
||||
|
||||
for param in model.parameters():
|
||||
param.register_post_accumulate_grad_hook(optimizer_hook)
|
||||
|
||||
optimizer_cls = LayerWiseDummyOptimizer
|
||||
optimizer_kwargs.update({"optimizer_dict": optimizer_dict})
|
||||
|
||||
optimizer_kwargs.update({"params": param_groups})
|
||||
|
||||
if args.optim == OptimizerNames.GALORE_ADAFACTOR:
|
||||
optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
|
||||
else:
|
||||
raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
|
||||
return optimizer_cls, optimizer_kwargs
|
||||
|
||||
@@ -34,6 +34,7 @@ import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import nn
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
from torch.utils.data import Dataset, IterableDataset, RandomSampler, Sampler
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
@@ -1226,3 +1227,47 @@ class AcceleratorConfig:
|
||||
|
||||
def to_dict(self):
|
||||
return copy.deepcopy(self.__dict__)
|
||||
|
||||
|
||||
class LayerWiseDummyOptimizer(torch.optim.Optimizer):
|
||||
"""
|
||||
For Layer-wise optimizers such as GaLoRE optimizer, the optimization
|
||||
step is already done through the post gradient hooks. Therefore
|
||||
the trick is to create a dummy optimizer that can take arbitrary
|
||||
args and kwargs and return a no-op during training.
|
||||
|
||||
Initial idea from @hiyouga in LLaMA-Factory:
|
||||
https://github.com/hiyouga/LLaMA-Factory/commit/8664262cde3919e10eaecbd66e8c5d356856362e#diff-ebe08ab14496dfb9e06075f0fdd36799ef6d1535cc4dd4715b74c4e3e06fe3ba
|
||||
"""
|
||||
|
||||
def __init__(self, optimizer_dict=None, *args, **kwargs):
|
||||
dummy_tensor = torch.randn(1, 1)
|
||||
self.optimizer_dict = optimizer_dict
|
||||
super().__init__([dummy_tensor], {"lr": 1e-03})
|
||||
|
||||
def zero_grad(self, set_to_none: bool = True) -> None:
|
||||
pass
|
||||
|
||||
def step(self, closure=None) -> Optional[float]:
|
||||
pass
|
||||
|
||||
|
||||
class LayerWiseDummyScheduler(LRScheduler):
|
||||
"""
|
||||
For Layer-wise optimizers such as GaLoRE optimizer, the optimization and scheduling step
|
||||
are already done through the post gradient hooks. Therefore
|
||||
the trick is to create a dummy scheduler that can take arbitrary
|
||||
args and kwargs and return a no-op during training.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
optimizer = LayerWiseDummyOptimizer()
|
||||
last_epoch = -1
|
||||
verbose = False
|
||||
super().__init__(optimizer, last_epoch, verbose)
|
||||
|
||||
def get_lr(self):
|
||||
return [group["lr"] for group in self.optimizer.param_groups]
|
||||
|
||||
def _get_closed_form_lr(self):
|
||||
return self.base_lrs
|
||||
|
||||
@@ -785,3 +785,42 @@ class RemoveColumnsCollator:
|
||||
def __call__(self, features: List[dict]):
|
||||
features = [self._remove_columns(feature) for feature in features]
|
||||
return self.data_collator(features)
|
||||
|
||||
|
||||
def check_target_module_exists(optim_target_modules, key: str, return_is_regex: bool = False):
|
||||
"""A helper method to check if the passed module's key name matches any of the target modules in the optim_target_modules.
|
||||
|
||||
Args:
|
||||
optim_target_modules (`Union[str, List[str]]`):
|
||||
A list of strings to try to match. Can be also a full string.
|
||||
key (`str`):
|
||||
A key to search any matches in optim_target_modules
|
||||
return_is_regex (`bool`):
|
||||
If set to `True`, the method will return whether the passed `optim_target_modules`
|
||||
is a regex or not.
|
||||
|
||||
Returns:
|
||||
`bool` : True of match object if key matches any target modules from config, False or
|
||||
None if no match found
|
||||
`bool` : If the matched target module is a regex to silence out the warnings in Trainer
|
||||
for extra modules being found (only if `target_module_found=True` for an array of regex).
|
||||
"""
|
||||
target_module_found = False
|
||||
is_regex = False
|
||||
|
||||
if isinstance(optim_target_modules, str):
|
||||
target_module_found = bool(re.fullmatch(optim_target_modules, key))
|
||||
is_regex = True if not optim_target_modules == key else False
|
||||
elif key in optim_target_modules: # from here, target_module_found must be a list of str
|
||||
# this module is specified directly in target_modules
|
||||
target_module_found = True
|
||||
elif any(target_key in key for target_key in optim_target_modules):
|
||||
target_module_found = True
|
||||
elif any(bool(re.fullmatch(optim_target_module, key)) for optim_target_module in optim_target_modules):
|
||||
target_module_found = True
|
||||
is_regex = True
|
||||
|
||||
if return_is_regex:
|
||||
return target_module_found, is_regex
|
||||
|
||||
return target_module_found
|
||||
|
||||
@@ -164,6 +164,12 @@ class OptimizerNames(ExplicitEnum):
|
||||
RMSPROP_BNB = "rmsprop_bnb"
|
||||
RMSPROP_8BIT = "rmsprop_bnb_8bit"
|
||||
RMSPROP_32BIT = "rmsprop_bnb_32bit"
|
||||
GALORE_ADAMW = "galore_adamw"
|
||||
GALORE_ADAMW_8BIT = "galore_adamw_8bit"
|
||||
GALORE_ADAFACTOR = "galore_adafactor"
|
||||
GALORE_ADAMW_LAYERWISE = "galore_adamw_layerwise"
|
||||
GALORE_ADAMW_8BIT_LAYERWISE = "galore_adamw_8bit_layerwise"
|
||||
GALORE_ADAFACTOR_LAYERWISE = "galore_adafactor_layerwise"
|
||||
|
||||
|
||||
# TODO: `TrainingArguments` users rely on it being fully mutable. In the future see if we can narrow this to a few keys: https://github.com/huggingface/transformers/pull/25903
|
||||
@@ -696,6 +702,12 @@ class TrainingArguments:
|
||||
for instruction fine-tuning. Check out the [original paper](https://arxiv.org/abs/2310.05914) and the
|
||||
[original code](https://github.com/neelsjain/NEFTune). Support transformers `PreTrainedModel` and also
|
||||
`PeftModel` from peft.
|
||||
optim_target_modules (`Union[str, List[str]]`, *optional*):
|
||||
The target modules to optimize, i.e. the module names that you would like to train, right now this is used only for GaLore algorithm
|
||||
https://arxiv.org/abs/2403.03507
|
||||
See: https://github.com/jiaweizzhao/GaLore for more details. You need to make sure to pass a valid GaloRe
|
||||
optimizer, e.g. one of: "galore_adamw", "galore_adamw_8bit", "galore_adafactor" and make sure that the target modules are `nn.Linear` modules
|
||||
only.
|
||||
"""
|
||||
|
||||
framework = "pt"
|
||||
@@ -1354,6 +1366,13 @@ class TrainingArguments:
|
||||
},
|
||||
)
|
||||
|
||||
optim_target_modules: Union[None, str, List[str]] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Target modules for the optimizer defined in the `optim` argument. Only used for the GaLore optimizer at the moment."
|
||||
},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
# expand paths, if not os.makedirs("~/bar") will make directory
|
||||
# in the current directory instead of the actual home
|
||||
|
||||
@@ -125,6 +125,7 @@ from .import_utils import (
|
||||
is_fsdp_available,
|
||||
is_ftfy_available,
|
||||
is_g2p_en_available,
|
||||
is_galore_torch_available,
|
||||
is_in_notebook,
|
||||
is_ipex_available,
|
||||
is_jieba_available,
|
||||
|
||||
@@ -95,6 +95,7 @@ _accelerate_available, _accelerate_version = _is_package_available("accelerate",
|
||||
_apex_available = _is_package_available("apex")
|
||||
_aqlm_available = _is_package_available("aqlm")
|
||||
_bitsandbytes_available = _is_package_available("bitsandbytes")
|
||||
_galore_torch_available = _is_package_available("galore_torch")
|
||||
# `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
|
||||
_bs4_available = importlib.util.find_spec("bs4") is not None
|
||||
_coloredlogs_available = _is_package_available("coloredlogs")
|
||||
@@ -309,6 +310,10 @@ def is_torchvision_available():
|
||||
return _torchvision_available
|
||||
|
||||
|
||||
def is_galore_torch_available():
|
||||
return _galore_torch_available
|
||||
|
||||
|
||||
def is_pyctcdecode_available():
|
||||
return _pyctcdecode_available
|
||||
|
||||
|
||||
@@ -60,6 +60,7 @@ from transformers.testing_utils import (
|
||||
require_accelerate,
|
||||
require_bitsandbytes,
|
||||
require_deepspeed,
|
||||
require_galore_torch,
|
||||
require_intel_extension_for_pytorch,
|
||||
require_optuna,
|
||||
require_peft,
|
||||
@@ -84,7 +85,7 @@ from transformers.testing_utils import (
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, HPSearchBackend
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, HPSearchBackend, check_target_module_exists
|
||||
from transformers.training_args import OptimizerNames
|
||||
from transformers.utils import (
|
||||
SAFE_WEIGHTS_INDEX_NAME,
|
||||
@@ -114,6 +115,8 @@ if is_torch_available():
|
||||
GPT2Config,
|
||||
GPT2LMHeadModel,
|
||||
LineByLineTextDataset,
|
||||
LlamaConfig,
|
||||
LlamaForCausalLM,
|
||||
PreTrainedModel,
|
||||
Trainer,
|
||||
TrainerState,
|
||||
@@ -146,6 +149,31 @@ class RegressionDataset:
|
||||
return result
|
||||
|
||||
|
||||
# Converting Bytes to Megabytes
|
||||
def bytes2megabytes(x):
|
||||
return int(x / 2**20)
|
||||
|
||||
|
||||
# Copied from acclerate: https://github.com/huggingface/accelerate/blob/ee163b66fb7848892519e804688cb4ae981aacbe/src/accelerate/test_utils/scripts/external_deps/test_peak_memory_usage.py#L40C1-L73C68
|
||||
class TorchTracemalloc:
|
||||
def __enter__(self):
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_max_memory_allocated() # reset the peak gauge to zero
|
||||
self.begin = torch.cuda.memory_allocated()
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc):
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
self.end = torch.cuda.memory_allocated()
|
||||
self.peak = torch.cuda.max_memory_allocated()
|
||||
self.used = bytes2megabytes(self.end - self.begin)
|
||||
self.peaked = bytes2megabytes(self.peak - self.begin)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class RegressionTrainingArguments(TrainingArguments):
|
||||
a: float = 0.0
|
||||
@@ -1069,6 +1097,293 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
trainer.train()
|
||||
trainer.evaluate()
|
||||
|
||||
def test_galore_matched_modules(self):
|
||||
regex_patterns = [r".*.attn.*", r".*.mlp.*"]
|
||||
|
||||
module_names = [
|
||||
"model.transformer.h.0.ln_1",
|
||||
"model.transformer.h.0.attn.q_proj",
|
||||
"model.lm_head",
|
||||
"model.transformer.h.0.mlp.up_proj",
|
||||
]
|
||||
expected_values = [False, True, False, True]
|
||||
|
||||
for expected_value, module_name in zip(expected_values, module_names):
|
||||
is_module_matched, is_regex = check_target_module_exists(regex_patterns, module_name, return_is_regex=True)
|
||||
self.assertTrue(is_module_matched == expected_value)
|
||||
if is_module_matched:
|
||||
self.assertTrue(is_regex)
|
||||
|
||||
exact_patterns = ["q_proj", "up_proj"]
|
||||
|
||||
module_names = [
|
||||
"model.transformer.h.0.ln_1",
|
||||
"model.transformer.h.0.attn.q_proj",
|
||||
"model.lm_head",
|
||||
"model.transformer.h.0.mlp.up_proj",
|
||||
]
|
||||
expected_values = [False, True, False, True]
|
||||
|
||||
for expected_value, module_name in zip(expected_values, module_names):
|
||||
is_module_matched, is_regex = check_target_module_exists(exact_patterns, module_name, return_is_regex=True)
|
||||
self.assertTrue(is_module_matched == expected_value)
|
||||
if is_module_matched:
|
||||
self.assertFalse(is_regex)
|
||||
|
||||
simple_regex = r".*.attn.*"
|
||||
|
||||
module_names = [
|
||||
"model.transformer.h.0.ln_1",
|
||||
"model.transformer.h.0.attn.q_proj",
|
||||
"model.lm_head",
|
||||
"model.transformer.h.0.mlp.up_proj",
|
||||
]
|
||||
expected_values = [False, True, False, False]
|
||||
|
||||
for expected_value, module_name in zip(expected_values, module_names):
|
||||
is_module_matched, is_regex = check_target_module_exists(simple_regex, module_name, return_is_regex=True)
|
||||
self.assertTrue(is_module_matched == expected_value)
|
||||
if is_module_matched:
|
||||
self.assertTrue(is_regex)
|
||||
|
||||
simple_regex = "model.transformer.h.0.attn.q_proj"
|
||||
|
||||
module_names = [
|
||||
"model.transformer.h.0.ln_1",
|
||||
"model.transformer.h.0.attn.q_proj",
|
||||
"model.lm_head",
|
||||
"model.transformer.h.0.mlp.up_proj",
|
||||
]
|
||||
expected_values = [False, True, False, False]
|
||||
|
||||
for expected_value, module_name in zip(expected_values, module_names):
|
||||
is_module_matched, is_regex = check_target_module_exists(simple_regex, module_name, return_is_regex=True)
|
||||
self.assertTrue(is_module_matched == expected_value)
|
||||
if is_module_matched:
|
||||
self.assertFalse(is_regex)
|
||||
|
||||
target_modules = ["attn", "mlp"]
|
||||
|
||||
module_names = [
|
||||
"model.transformer.h.0.ln_1",
|
||||
"model.transformer.h.0.attn.q_proj",
|
||||
"model.lm_head",
|
||||
"model.transformer.h.0.mlp.up_proj",
|
||||
]
|
||||
expected_values = [False, True, False, True]
|
||||
|
||||
for expected_value, module_name in zip(expected_values, module_names):
|
||||
is_module_matched, is_regex = check_target_module_exists(target_modules, module_name, return_is_regex=True)
|
||||
self.assertTrue(is_module_matched == expected_value)
|
||||
if is_module_matched:
|
||||
self.assertFalse(is_regex)
|
||||
|
||||
@require_galore_torch
|
||||
@require_torch_gpu
|
||||
def test_galore(self):
|
||||
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||
tiny_llama = LlamaForCausalLM(config)
|
||||
x = torch.randint(0, 100, (128,))
|
||||
train_dataset = RepeatDataset(x)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Trainer without inf/nan filter
|
||||
args = TrainingArguments(
|
||||
tmpdir,
|
||||
learning_rate=1e-9,
|
||||
logging_steps=5,
|
||||
optim="galore_adamw",
|
||||
optim_target_modules=[r".*attn.*", r".*mlp.*"],
|
||||
)
|
||||
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
|
||||
|
||||
# Check this works
|
||||
_ = trainer.train()
|
||||
|
||||
@require_galore_torch
|
||||
@require_torch_gpu
|
||||
def test_galore_extra_args(self):
|
||||
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||
tiny_llama = LlamaForCausalLM(config)
|
||||
x = torch.randint(0, 100, (128,))
|
||||
train_dataset = RepeatDataset(x)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Trainer without inf/nan filter
|
||||
args = TrainingArguments(
|
||||
tmpdir,
|
||||
learning_rate=1e-9,
|
||||
logging_steps=5,
|
||||
optim="galore_adamw",
|
||||
optim_args="rank=64, update_proj_gap=100, scale=0.10",
|
||||
optim_target_modules=[r".*attn.*", r".*mlp.*"],
|
||||
)
|
||||
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
|
||||
|
||||
# Check this works
|
||||
_ = trainer.train()
|
||||
|
||||
@require_galore_torch
|
||||
@require_torch_gpu
|
||||
def test_galore_layerwise(self):
|
||||
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||
tiny_llama = LlamaForCausalLM(config)
|
||||
x = torch.randint(0, 100, (128,))
|
||||
train_dataset = RepeatDataset(x)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Trainer without inf/nan filter
|
||||
args = TrainingArguments(
|
||||
tmpdir,
|
||||
learning_rate=1e-9,
|
||||
logging_steps=5,
|
||||
optim="galore_adamw_layerwise",
|
||||
optim_target_modules=[r".*attn.*", r".*mlp.*"],
|
||||
)
|
||||
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
|
||||
|
||||
# Check this works
|
||||
_ = trainer.train()
|
||||
|
||||
@require_galore_torch
|
||||
@require_torch_gpu
|
||||
def test_galore_layerwise_with_scheduler(self):
|
||||
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||
tiny_llama = LlamaForCausalLM(config)
|
||||
x = torch.randint(0, 100, (128,))
|
||||
train_dataset = RepeatDataset(x)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Trainer without inf/nan filter
|
||||
args = TrainingArguments(
|
||||
tmpdir,
|
||||
learning_rate=1e-9,
|
||||
logging_steps=5,
|
||||
optim="galore_adamw_layerwise",
|
||||
lr_scheduler_type="cosine",
|
||||
optim_target_modules=[r".*attn.*", r".*mlp.*"],
|
||||
)
|
||||
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
|
||||
|
||||
# Check this works
|
||||
_ = trainer.train()
|
||||
|
||||
@require_galore_torch
|
||||
@require_torch_gpu
|
||||
def test_galore_adamw_8bit(self):
|
||||
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||
tiny_llama = LlamaForCausalLM(config)
|
||||
x = torch.randint(0, 100, (128,))
|
||||
train_dataset = RepeatDataset(x)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Trainer without inf/nan filter
|
||||
args = TrainingArguments(
|
||||
tmpdir,
|
||||
learning_rate=1e-9,
|
||||
logging_steps=5,
|
||||
optim="galore_adamw_8bit",
|
||||
optim_target_modules=[r".*attn.*", r".*mlp.*"],
|
||||
)
|
||||
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
|
||||
|
||||
# Check this works
|
||||
_ = trainer.train()
|
||||
|
||||
@require_galore_torch
|
||||
@require_torch_gpu
|
||||
def test_galore_adafactor(self):
|
||||
# These are the intervals of the peak memory usage of training such a tiny model
|
||||
# if the peak memory goes outside that range, then we know there might be a bug somewhere
|
||||
upper_bound_pm = 700
|
||||
lower_bound_pm = 650
|
||||
|
||||
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||
tiny_llama = LlamaForCausalLM(config)
|
||||
x = torch.randint(0, 100, (128,))
|
||||
train_dataset = RepeatDataset(x)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir, TorchTracemalloc() as tracemalloc:
|
||||
# Trainer without inf/nan filter
|
||||
args = TrainingArguments(
|
||||
tmpdir,
|
||||
learning_rate=1e-9,
|
||||
logging_steps=5,
|
||||
optim="galore_adafactor",
|
||||
optim_target_modules=[r".*attn.*", r".*mlp.*"],
|
||||
)
|
||||
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
|
||||
|
||||
# Check this works
|
||||
_ = trainer.train()
|
||||
|
||||
galore_peak_memory = tracemalloc.peaked + bytes2megabytes(tracemalloc.begin)
|
||||
|
||||
self.assertTrue(galore_peak_memory < upper_bound_pm)
|
||||
self.assertTrue(lower_bound_pm < galore_peak_memory)
|
||||
|
||||
@require_galore_torch
|
||||
@require_torch_gpu
|
||||
def test_galore_adafactor_attention_only(self):
|
||||
# These are the intervals of the peak memory usage of training such a tiny model
|
||||
# if the peak memory goes outside that range, then we know there might be a bug somewhere
|
||||
upper_bound_pm = 700
|
||||
lower_bound_pm = 650
|
||||
|
||||
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||
tiny_llama = LlamaForCausalLM(config)
|
||||
x = torch.randint(0, 100, (128,))
|
||||
train_dataset = RepeatDataset(x)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir, TorchTracemalloc() as tracemalloc:
|
||||
# Trainer without inf/nan filter
|
||||
args = TrainingArguments(
|
||||
tmpdir,
|
||||
learning_rate=1e-9,
|
||||
logging_steps=5,
|
||||
optim="galore_adafactor",
|
||||
optim_target_modules=["q_proj", "k_proj", "v_proj"],
|
||||
)
|
||||
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
|
||||
|
||||
# Check this works
|
||||
_ = trainer.train()
|
||||
|
||||
galore_peak_memory = tracemalloc.peaked + bytes2megabytes(tracemalloc.begin)
|
||||
self.assertTrue(galore_peak_memory < upper_bound_pm)
|
||||
self.assertTrue(lower_bound_pm < galore_peak_memory)
|
||||
|
||||
@require_galore_torch
|
||||
@require_torch_gpu
|
||||
def test_galore_adafactor_all_linear(self):
|
||||
# These are the intervals of the peak memory usage of training such a tiny model
|
||||
# if the peak memory goes outside that range, then we know there might be a bug somewhere
|
||||
upper_bound_pm = 700
|
||||
lower_bound_pm = 650
|
||||
|
||||
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||
tiny_llama = LlamaForCausalLM(config)
|
||||
x = torch.randint(0, 100, (128,))
|
||||
train_dataset = RepeatDataset(x)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir, TorchTracemalloc() as tracemalloc:
|
||||
# Trainer without inf/nan filter
|
||||
args = TrainingArguments(
|
||||
tmpdir,
|
||||
learning_rate=1e-9,
|
||||
logging_steps=5,
|
||||
optim="galore_adafactor",
|
||||
optim_target_modules="all-linear",
|
||||
)
|
||||
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
|
||||
|
||||
# Check this works
|
||||
_ = trainer.train()
|
||||
|
||||
galore_peak_memory = tracemalloc.peaked + bytes2megabytes(tracemalloc.begin)
|
||||
self.assertTrue(galore_peak_memory < upper_bound_pm)
|
||||
self.assertTrue(lower_bound_pm < galore_peak_memory)
|
||||
|
||||
@require_torch_multi_accelerator
|
||||
def test_data_is_not_parallelized_when_model_is_parallel(self):
|
||||
model = RegressionModel()
|
||||
|
||||
Reference in New Issue
Block a user