FEAT / Optim: Add GaLore optimizer (#29588)
* add galore v1
* add import
* add tests and doc
* fix doctest
* forward contrib credits from discussions
* forward contrib credits from discussions
* Apply suggestions from code review

Co-authored-by: Zach Mueller <muellerzr@gmail.com>

* fix failing tests'
* switch to `optim_target_modules` and clarify docs
* more clarification
* enhance lookup logic
* update a test to add peak memory
* add regex, all-linear and single string support
* add layer-wise optimization through DummyOptimizers and LRSchedulers
* forward contrib credits from discussions and original idea
* add a section about DDP not supported in layerwise
* Update src/transformers/trainer.py

Co-authored-by: Zach Mueller <muellerzr@gmail.com>

* fix self
* check only if layer_wise
* Update src/transformers/training_args.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* oops
* make use of intervals
* clarify comment
* add matching tests
* GaLoRe -> GaLore
* move to `get_scheduler`
* add note on docs
* add a warning
* adapt a bit the docs
* update docstring
* support original API
* Update docs/source/en/trainer.md
* slightly refactor
* Update docs/source/en/trainer.md

Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>

* Update src/transformers/training_args.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* fix args parsing and add tests
* remove warning for regex
* fix type hint
* add note about extra args
* make `is_regex` return optional

---------

Co-authored-by: Maxime <maximegmd @users.noreply.github.com>
Co-authored-by: Wing Lian <winglian @users.noreply.github.com>
Co-authored-by: Zach Mueller <muellerzr@gmail.com>
Co-authored-by: hiyouga <hiyouga@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
@@ -252,6 +252,136 @@ trainer = Trainer(..., args=training_args)
 
 NEFTune is disabled after training to restore the original embedding layer to avoid any unexpected behavior.
 
+## GaLore
+
+Gradient Low-Rank Projection (GaLore) is a memory-efficient low-rank training strategy that allows full-parameter learning but is more memory-efficient than common low-rank adaptation methods, such as LoRA.
+
+First, install the official GaLore package:
+
+```bash
+pip install galore-torch
+```
+
+Then set `optim` to one of `["galore_adamw", "galore_adafactor", "galore_adamw_8bit"]` and pass `optim_target_modules`, which can be a list of strings, a regex, or a full path matching the names of the modules you want to adapt. Below is an end-to-end example script (make sure to `pip install trl datasets`):
+
+```python
+import torch
+import datasets
+import trl
+
+from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM
+
+train_dataset = datasets.load_dataset('imdb', split='train')
+
+args = TrainingArguments(
+    output_dir="./test-galore",
+    max_steps=100,
+    per_device_train_batch_size=2,
+    optim="galore_adamw",
+    optim_target_modules=["attn", "mlp"]
+)
+
+model_id = "google/gemma-2b"
+
+config = AutoConfig.from_pretrained(model_id)
+
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = AutoModelForCausalLM.from_config(config).to(0)
+
+trainer = trl.SFTTrainer(
+    model=model,
+    args=args,
+    train_dataset=train_dataset,
+    dataset_text_field='text',
+    max_seq_length=512,
+)
+
+trainer.train()
+```
+
+To pass extra arguments supported by GaLore, set `optim_args` accordingly, for example:
+
+```python
+import torch
+import datasets
+import trl
+
+from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM
+
+train_dataset = datasets.load_dataset('imdb', split='train')
+
+args = TrainingArguments(
+    output_dir="./test-galore",
+    max_steps=100,
+    per_device_train_batch_size=2,
+    optim="galore_adamw",
+    optim_target_modules=["attn", "mlp"],
+    optim_args="rank=64, update_proj_gap=100, scale=0.10",
+)
+
+model_id = "google/gemma-2b"
+
+config = AutoConfig.from_pretrained(model_id)
+
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = AutoModelForCausalLM.from_config(config).to(0)
+
+trainer = trl.SFTTrainer(
+    model=model,
+    args=args,
+    train_dataset=train_dataset,
+    dataset_text_field='text',
+    max_seq_length=512,
+)
+
+trainer.train()
+```
+
+You can read more about the method in the [original repository](https://github.com/jiaweizzhao/GaLore) or the [paper](https://arxiv.org/abs/2403.03507).
+
+Currently, only Linear layers can be treated as GaLore layers and trained with low-rank decomposition; the remaining layers are optimized in the conventional manner.
+
+Note that it takes a bit of time before training starts (~3 minutes for a 2B model on an NVIDIA A100), but training should go smoothly afterwards.
+
+You can also perform layer-wise optimization by appending `layerwise` to the optimizer name, as in the example below:
+
+```python
+import torch
+import datasets
+import trl
+
+from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM
+
+train_dataset = datasets.load_dataset('imdb', split='train')
+
+args = TrainingArguments(
+    output_dir="./test-galore",
+    max_steps=100,
+    per_device_train_batch_size=2,
+    optim="galore_adamw_layerwise",
+    optim_target_modules=["attn", "mlp"]
+)
+
+model_id = "google/gemma-2b"
+
+config = AutoConfig.from_pretrained(model_id)
+
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = AutoModelForCausalLM.from_config(config).to(0)
+
+trainer = trl.SFTTrainer(
+    model=model,
+    args=args,
+    train_dataset=train_dataset,
+    dataset_text_field='text',
+    max_seq_length=512,
+)
+
+trainer.train()
+```
+
+Note that layer-wise optimization is somewhat experimental and does not support DDP (Distributed Data Parallel), so you can only run the training script on a single GPU. See [the relevant section](https://github.com/jiaweizzhao/GaLore?tab=readme-ov-file#train-7b-model-with-a-single-gpu-with-24gb-memory) of the GaLore README for more details. Other features such as gradient clipping, DeepSpeed, etc. might not be supported out of the box. Please [raise an issue on GitHub](https://github.com/huggingface/transformers/issues) if you encounter such an issue.
+
 ## Accelerate and Trainer
 
 The [`Trainer`] class is powered by [Accelerate](https://hf.co/docs/accelerate), a library for easily training PyTorch models in distributed environments with support for integrations such as [FullyShardedDataParallel (FSDP)](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/) and [DeepSpeed](https://www.deepspeed.ai/).
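The docs hunk above passes GaLore settings as a single string, `optim_args="rank=64, update_proj_gap=100, scale=0.10"`. The sketch below shows one way such a string could be turned into typed keyword arguments. The defaults (`rank=128`, `update_proj_gap=200`, `scale=0.25`, `proj_type="std"`) are taken from the `trainer.py` hunk in this commit; the helper names `parse_optim_args` and `galore_kwargs` are hypothetical, and the comma/equals splitting is an assumption, since the parsing code itself is not part of this diff.

```python
def parse_optim_args(optim_args: str) -> dict:
    """Split "key=value, key=value" into a dict of raw strings (assumed format)."""
    parsed = {}
    if not optim_args:
        return parsed
    for mapping in optim_args.replace(" ", "").split(","):
        key, value = mapping.split("=")
        parsed[key] = value
    return parsed


def galore_kwargs(optim_args: str) -> dict:
    """Cast the raw strings to the types GaLore expects, using the commit's defaults."""
    raw = parse_optim_args(optim_args)
    return {
        "rank": int(raw.pop("rank", 128)),
        "update_proj_gap": int(raw.pop("update_proj_gap", 200)),
        "scale": float(raw.pop("scale", 0.25)),
        "proj_type": raw.pop("proj_type", "std"),
    }


print(galore_kwargs("rank=64, update_proj_gap=100, scale=0.10"))
```

Because every value arrives as a string, the explicit `int(...)` and `float(...)` casts matter: without them the optimizer would receive `"64"` rather than `64`.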
@@ -24,6 +24,7 @@ from torch import nn
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
 
+from .trainer_pt_utils import LayerWiseDummyOptimizer, LayerWiseDummyScheduler
 from .trainer_utils import SchedulerType
 from .utils import logging
 from .utils.versions import require_version
@@ -362,6 +363,32 @@ def get_scheduler(
     """
     name = SchedulerType(name)
     schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
+
+    # If a `LayerWiseDummyOptimizer` is passed we extract the optimizer dict and
+    # recursively call `get_scheduler` to get the proper schedulers on each parameter
+    if optimizer is not None and isinstance(optimizer, LayerWiseDummyOptimizer):
+        optimizer_dict = optimizer.optimizer_dict
+        scheduler_dict = {}
+
+        for param in optimizer_dict.keys():
+            scheduler_dict[param] = get_scheduler(
+                name,
+                optimizer=optimizer_dict[param],
+                num_warmup_steps=num_warmup_steps,
+                num_training_steps=num_training_steps,
+            )
+
+        def scheduler_hook(param):
+            # Since the optimizer hook has been already attached we only need to
+            # attach the scheduler hook
+            if param.grad is not None:
+                scheduler_dict[param].step()
+
+        for param in optimizer_dict.keys():
+            param.register_post_accumulate_grad_hook(scheduler_hook)
+
+        return LayerWiseDummyScheduler()
+
     if name == SchedulerType.CONSTANT:
         return schedule_func(optimizer)
 
@@ -70,6 +70,7 @@ from .utils import (
     is_fsdp_available,
     is_ftfy_available,
     is_g2p_en_available,
+    is_galore_torch_available,
     is_ipex_available,
     is_jieba_available,
     is_jinja_available,
@@ -325,6 +326,14 @@ def require_bs4(test_case):
     return unittest.skipUnless(is_bs4_available(), "test requires BeautifulSoup4")(test_case)
 
 
+def require_galore_torch(test_case):
+    """
+    Decorator marking a test that requires GaLore. These tests are skipped when GaLore isn't installed.
+    https://github.com/jiaweizzhao/GaLore
+    """
+    return unittest.skipUnless(is_galore_torch_available(), "test requires GaLore")(test_case)
+
+
 def require_cv2(test_case):
     """
     Decorator marking a test that requires OpenCV.
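`require_galore_torch` follows the usual pattern of wrapping `unittest.skipUnless` around an availability check. A stdlib-only sketch of the same pattern is given below; `is_package_available` and `require_package` are hypothetical names, and the check is simplified to `importlib.util.find_spec`, which may differ from what the library's own `_is_package_available` does.

```python
import importlib.util
import unittest


def is_package_available(name: str) -> bool:
    # Simplified availability check: can the top-level module be located?
    return importlib.util.find_spec(name) is not None


def require_package(name: str):
    """Build a decorator that skips the test unless `name` can be imported."""

    def decorator(test_case):
        return unittest.skipUnless(is_package_available(name), f"test requires {name}")(test_case)

    return decorator


class ExampleTest(unittest.TestCase):
    @require_package("json")
    def test_runs(self):
        self.assertTrue(True)

    @require_package("surely_not_an_installed_package_xyz")
    def test_is_skipped(self):
        self.fail("should have been skipped")


suite = unittest.defaultTestLoader.loadTestsFromTestCase(ExampleTest)
result = unittest.TestResult()
suite.run(result)
print(len(result.skipped), result.wasSuccessful())
```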
@@ -83,6 +83,7 @@ from .trainer_pt_utils import (
     DistributedTensorGatherer,
     IterableDatasetShard,
     LabelSmoother,
+    LayerWiseDummyOptimizer,
     LengthGroupedSampler,
     SequentialDistributedSampler,
     distributed_broadcast_scalars,
@@ -111,6 +112,7 @@ from .trainer_utils import (
     RemoveColumnsCollator,
     TrainerMemoryTracker,
     TrainOutput,
+    check_target_module_exists,
     default_compute_objective,
     denumpify_detensorize,
     enable_full_determinism,
@@ -141,6 +143,7 @@ from .utils import (
     is_apex_available,
     is_bitsandbytes_available,
     is_datasets_available,
+    is_galore_torch_available,
     is_in_notebook,
     is_ipex_available,
     is_peft_available,
@@ -1010,7 +1013,17 @@ class Trainer:
                 },
             ]
 
-            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
+            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args, opt_model)
+
+            # Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
+            # e.g. for GaLore optimizer.
+            if "params" in optimizer_kwargs:
+                optimizer_grouped_parameters = optimizer_kwargs.pop("params")
+
+            # For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
+            # to avoid arguments conflicts.
+            if "optimizer_dict" in optimizer_kwargs:
+                optimizer_grouped_parameters = optimizer_kwargs.pop("optimizer_dict")
 
             self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
             if optimizer_cls.__name__ == "Adam8bit":
@@ -1033,7 +1046,9 @@ class Trainer:
         return self.optimizer
 
     @staticmethod
-    def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:
+    def get_optimizer_cls_and_kwargs(
+        args: TrainingArguments, model: Optional[PreTrainedModel] = None
+    ) -> Tuple[Any, Any]:
         """
         Returns the optimizer class and optimizer parameters based on the training arguments.
 
@@ -1171,6 +1186,132 @@ class Trainer:
             optimizer_cls = torch.optim.Adagrad
         elif args.optim == OptimizerNames.RMSPROP:
             optimizer_cls = torch.optim.RMSprop
+        elif args.optim in [
+            OptimizerNames.GALORE_ADAMW,
+            OptimizerNames.GALORE_ADAMW_8BIT,
+            OptimizerNames.GALORE_ADAFACTOR,
+            OptimizerNames.GALORE_ADAMW_LAYERWISE,
+            OptimizerNames.GALORE_ADAMW_8BIT_LAYERWISE,
+            OptimizerNames.GALORE_ADAFACTOR_LAYERWISE,
+        ]:
+            if not is_galore_torch_available():
+                raise ImportError(
+                    "You need to install `galore_torch` in order to use GaLore optimizers"
+                    " install it with `pip install git+https://github.com/jiaweizzhao/GaLore`"
+                )
+            from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit
+
+            is_layerwise = args.optim.lower().endswith("layerwise")
+            if is_layerwise and args.parallel_mode == ParallelMode.DISTRIBUTED:
+                raise NotImplementedError("Layer-wise GaLore does not support DDP at this time")
+
+            optimizer_mapping = {
+                OptimizerNames.GALORE_ADAMW: GaLoreAdamW,
+                OptimizerNames.GALORE_ADAMW_8BIT: GaLoreAdamW8bit,
+                OptimizerNames.GALORE_ADAFACTOR: GaLoreAdafactor,
+                OptimizerNames.GALORE_ADAMW_LAYERWISE: GaLoreAdamW,
+                OptimizerNames.GALORE_ADAMW_8BIT_LAYERWISE: GaLoreAdamW8bit,
+                OptimizerNames.GALORE_ADAFACTOR_LAYERWISE: GaLoreAdafactor,
+            }
+
+            optimizer_cls = optimizer_mapping[args.optim]
+
+            if args.optim_target_modules is None:
+                raise ValueError(
+                    "You need to define a `optim_target_modules` in order to properly use GaLore optimizers"
+                )
+
+            if not isinstance(args.optim_target_modules, (list, str)):
+                raise ValueError(
+                    f"`optim_target_modules` has to be a list of strings, a string corresponding to a regex, or a specific module or 'all-linear', you passed {args.optim_target_modules}"
+                )
+
+            if model is None:
+                raise ValueError("You need to pass a model in order to correctly initialize a GaLore optimizer.")
+
+            logger.warning(
+                "Activated GaLoRE fine-tuning, depending on your model size and hardware, the training might take a while before starting. Please be patient !"
+            )
+
+            all_linear = (
+                isinstance(args.optim_target_modules, str)
+                and args.optim_target_modules.replace("_", "-") == "all-linear"
+            )
+
+            galore_params = []
+            galore_params_names = []
+            for module_name, module in model.named_modules():
+                target_module_exists, is_regex = check_target_module_exists(
+                    args.optim_target_modules, module_name, return_is_regex=True
+                )
+
+                if not isinstance(module, nn.Linear):
+                    # Warn in case we match but it's not a linear layer
+                    if target_module_exists and not is_regex:
+                        logger.warning(
+                            f"{module_name} has been matched but ignored as GaLore only supports linear layers. Please double check your `optim_target_modules`!"
+                        )
+
+                    continue
+
+                if not target_module_exists and not all_linear:
+                    continue
+
+                galore_params.append(module.weight)
+                galore_params_names.append(module_name + ".weight")
+
+            if len(galore_params) == 0:
+                raise ValueError(
+                    f"None of the target modules were found! ({args.optim_target_modules}). Please make sure to pass a valid `target_modules`."
+                )
+
+            non_galore_params = [p for n, p in model.named_parameters() if n not in galore_params_names]
+
+            galore_optim_kwargs = {
+                "rank": int(optim_args.pop("rank", 128)),
+                "update_proj_gap": int(optim_args.pop("update_proj_gap", 200)),
+                "scale": float(optim_args.pop("scale", 0.25)),
+                "proj_type": optim_args.pop("proj_type", "std"),
+            }
+
+            # The default args are from the official repository: https://github.com/jiaweizzhao/GaLore
+            param_groups = [
+                {"params": non_galore_params},
+                {"params": galore_params, **galore_optim_kwargs},
+            ]
+
+            if is_layerwise:
+                # For layer-wise optimizers, the optimization step is done through post accumulation
+                # gradient hooks. The trick is to first attach these hooks to the model parameters then
+                # create a dummy optimizer that will perform no-ops in the Trainer.
+                # See the original implementation or the nice implementation from @hiyouga
+                # here: https://github.com/hiyouga/LLaMA-Factory/commit/8664262cde3919e10eaecbd66e8c5d356856362e#diff-ebe08ab14496dfb9e06075f0fdd36799ef6d1535cc4dd4715b74c4e3e06fe3ba
+                if args.gradient_accumulation_steps != 1:
+                    raise ValueError("Layerwise GaLoRE optimizer do not support gradient accumulation !")
+
+                optimizer_dict = {}
+                for param in non_galore_params:
+                    param_groups = [{"params": [param]}]
+                    optimizer_dict[param] = optimizer_cls(param_groups, **optimizer_kwargs)
+                for param in galore_params:
+                    param_groups = [{"params": [param], **galore_optim_kwargs}]
+                    optimizer_dict[param] = optimizer_cls(param_groups, **optimizer_kwargs)
+
+                def optimizer_hook(param):
+                    if param.grad is not None:
+                        optimizer_dict[param].step()
+                        optimizer_dict[param].zero_grad()
+
+                for param in model.parameters():
+                    param.register_post_accumulate_grad_hook(optimizer_hook)
+
+                optimizer_cls = LayerWiseDummyOptimizer
+                optimizer_kwargs.update({"optimizer_dict": optimizer_dict})
+
+            optimizer_kwargs.update({"params": param_groups})
+
+            if args.optim == OptimizerNames.GALORE_ADAFACTOR:
+                optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
         else:
             raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
         return optimizer_cls, optimizer_kwargs
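The hunk above splits the model's parameters into a GaLore group (weights of matched `nn.Linear` modules) and a conventional group (everything else), then attaches the GaLore settings to the first group only. A small-scale sketch of that split follows. It is an illustration under assumptions: `TinyBlock` is a hypothetical toy module, matching is reduced to a substring test, and `torch.optim.AdamW` stands in for the GaLore optimizer (it stores the extra group keys but does not act on them).

```python
import torch
from torch import nn


class TinyBlock(nn.Module):
    # Hypothetical toy module with two linear layers and one non-linear-layer module.
    def __init__(self):
        super().__init__()
        self.attn = nn.Linear(8, 8)
        self.mlp = nn.Linear(8, 8)
        self.norm = nn.LayerNorm(8)


model = TinyBlock()
targets = ["attn", "mlp"]

galore_params, galore_names = [], []
for name, module in model.named_modules():
    # Only the weight matrix of a matched linear layer goes to the GaLore group.
    if isinstance(module, nn.Linear) and any(t in name for t in targets):
        galore_params.append(module.weight)
        galore_names.append(name + ".weight")

# Biases, LayerNorm parameters, and unmatched weights are optimized conventionally.
non_galore_params = [p for n, p in model.named_parameters() if n not in galore_names]

param_groups = [
    {"params": non_galore_params},
    {"params": galore_params, "rank": 128, "update_proj_gap": 200, "scale": 0.25, "proj_type": "std"},
]

# AdamW is a stand-in here; a GaLore optimizer would read the extra keys of the second group.
optimizer = torch.optim.AdamW(param_groups, lr=1e-3)
print(len(galore_params), len(non_galore_params))
```

Keeping the per-group settings inside the param group dictionaries is what lets a single optimizer instance treat the two sets of parameters differently.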
@@ -34,6 +34,7 @@ import numpy as np
 import torch
 import torch.distributed as dist
 from torch import nn
+from torch.optim.lr_scheduler import LRScheduler
 from torch.utils.data import Dataset, IterableDataset, RandomSampler, Sampler
 from torch.utils.data.distributed import DistributedSampler
 
@@ -1226,3 +1227,47 @@ class AcceleratorConfig:
 
     def to_dict(self):
         return copy.deepcopy(self.__dict__)
+
+
+class LayerWiseDummyOptimizer(torch.optim.Optimizer):
+    """
+    For layer-wise optimizers such as the GaLore optimizer, the optimization
+    step is already done through the post gradient hooks. Therefore
+    the trick is to create a dummy optimizer that can take arbitrary
+    args and kwargs and return a no-op during training.
+
+    Initial idea from @hiyouga in LLaMA-Factory:
+    https://github.com/hiyouga/LLaMA-Factory/commit/8664262cde3919e10eaecbd66e8c5d356856362e#diff-ebe08ab14496dfb9e06075f0fdd36799ef6d1535cc4dd4715b74c4e3e06fe3ba
+    """
+
+    def __init__(self, optimizer_dict=None, *args, **kwargs):
+        dummy_tensor = torch.randn(1, 1)
+        self.optimizer_dict = optimizer_dict
+        super().__init__([dummy_tensor], {"lr": 1e-03})
+
+    def zero_grad(self, set_to_none: bool = True) -> None:
+        pass
+
+    def step(self, closure=None) -> Optional[float]:
+        pass
+
+
+class LayerWiseDummyScheduler(LRScheduler):
+    """
+    For layer-wise optimizers such as the GaLore optimizer, the optimization and scheduling steps
+    are already done through the post gradient hooks. Therefore
+    the trick is to create a dummy scheduler that can take arbitrary
+    args and kwargs and return a no-op during training.
+    """
+
+    def __init__(self, *args, **kwargs):
+        optimizer = LayerWiseDummyOptimizer()
+        last_epoch = -1
+        verbose = False
+        super().__init__(optimizer, last_epoch, verbose)
+
+    def get_lr(self):
+        return [group["lr"] for group in self.optimizer.param_groups]
+
+    def _get_closed_form_lr(self):
+        return self.base_lrs
@@ -785,3 +785,42 @@ class RemoveColumnsCollator:
     def __call__(self, features: List[dict]):
         features = [self._remove_columns(feature) for feature in features]
         return self.data_collator(features)
+
+
+def check_target_module_exists(optim_target_modules, key: str, return_is_regex: bool = False):
+    """A helper method to check if the passed module's key name matches any of the target modules in the optim_target_modules.
+
+    Args:
+        optim_target_modules (`Union[str, List[str]]`):
+            A list of strings to try to match. Can be also a full string.
+        key (`str`):
+            A key to search any matches in optim_target_modules
+        return_is_regex (`bool`):
+            If set to `True`, the method will return whether the passed `optim_target_modules`
+            is a regex or not.
+
+    Returns:
+        `bool` : True or match object if key matches any target modules from config, False or
+        None if no match found
+        `bool` : If the matched target module is a regex to silence out the warnings in Trainer
+        for extra modules being found (only if `target_module_found=True` for an array of regex).
+    """
+    target_module_found = False
+    is_regex = False
+
+    if isinstance(optim_target_modules, str):
+        target_module_found = bool(re.fullmatch(optim_target_modules, key))
+        is_regex = True if not optim_target_modules == key else False
+    elif key in optim_target_modules:  # from here, optim_target_modules must be a list of str
+        # this module is specified directly in target_modules
+        target_module_found = True
+    elif any(target_key in key for target_key in optim_target_modules):
+        target_module_found = True
+    elif any(bool(re.fullmatch(optim_target_module, key)) for optim_target_module in optim_target_modules):
+        target_module_found = True
+        is_regex = True
+
+    if return_is_regex:
+        return target_module_found, is_regex
+
+    return target_module_found
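To make the matching rules of `check_target_module_exists` concrete, here is a condensed, stdlib-only restatement with a few sample inputs. `matches_target` is a hypothetical name for this sketch, it always returns the `(found, is_regex)` pair, and the module names are made-up examples in the style of a decoder-only model.

```python
import re


def matches_target(targets, key):
    """Return (found, is_regex) following the branch order of the helper above."""
    if isinstance(targets, str):
        # A single string is tried as a full regex match against the module name.
        return bool(re.fullmatch(targets, key)), targets != key
    if key in targets:
        return True, False  # exact module name listed
    if any(t in key for t in targets):
        return True, False  # substring match, e.g. "attn" in "...self_attn.q_proj"
    if any(re.fullmatch(t, key) for t in targets):
        return True, True  # one list entry matched as a regex
    return False, False


print(matches_target(["attn", "mlp"], "model.layers.0.self_attn.q_proj"))
print(matches_target(r".*\.mlp\..*_proj", "model.layers.0.mlp.up_proj"))
print(matches_target("model.layers.0.mlp.up_proj", "model.layers.0.mlp.up_proj"))
print(matches_target(["attn"], "model.norm"))
```

Note that for list inputs the substring test runs before the regex test, so short entries such as `"attn"` match every module whose name contains that text.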
@@ -164,6 +164,12 @@ class OptimizerNames(ExplicitEnum):
     RMSPROP_BNB = "rmsprop_bnb"
     RMSPROP_8BIT = "rmsprop_bnb_8bit"
     RMSPROP_32BIT = "rmsprop_bnb_32bit"
+    GALORE_ADAMW = "galore_adamw"
+    GALORE_ADAMW_8BIT = "galore_adamw_8bit"
+    GALORE_ADAFACTOR = "galore_adafactor"
+    GALORE_ADAMW_LAYERWISE = "galore_adamw_layerwise"
+    GALORE_ADAMW_8BIT_LAYERWISE = "galore_adamw_8bit_layerwise"
+    GALORE_ADAFACTOR_LAYERWISE = "galore_adafactor_layerwise"
 
 
 # TODO: `TrainingArguments` users rely on it being fully mutable. In the future see if we can narrow this to a few keys: https://github.com/huggingface/transformers/pull/25903
@@ -696,6 +702,12 @@ class TrainingArguments:
             for instruction fine-tuning. Check out the [original paper](https://arxiv.org/abs/2310.05914) and the
             [original code](https://github.com/neelsjain/NEFTune). Support transformers `PreTrainedModel` and also
             `PeftModel` from peft.
+        optim_target_modules (`Union[str, List[str]]`, *optional*):
+            The target modules to optimize, i.e. the module names that you would like to train. Right now this is used only for the GaLore algorithm:
+            https://arxiv.org/abs/2403.03507
+            See: https://github.com/jiaweizzhao/GaLore for more details. You need to make sure to pass a valid GaLore
+            optimizer, e.g. one of: "galore_adamw", "galore_adamw_8bit", "galore_adafactor", and make sure that the target modules are `nn.Linear` modules
+            only.
     """
 
     framework = "pt"
@@ -1354,6 +1366,13 @@ class TrainingArguments:
         },
     )
 
+    optim_target_modules: Union[None, str, List[str]] = field(
+        default=None,
+        metadata={
+            "help": "Target modules for the optimizer defined in the `optim` argument. Only used for the GaLore optimizer at the moment."
+        },
+    )
+
     def __post_init__(self):
         # expand paths, if not os.makedirs("~/bar") will make directory
         # in the current directory instead of the actual home
@@ -125,6 +125,7 @@ from .import_utils import (
     is_fsdp_available,
     is_ftfy_available,
     is_g2p_en_available,
+    is_galore_torch_available,
     is_in_notebook,
     is_ipex_available,
     is_jieba_available,
@@ -95,6 +95,7 @@ _accelerate_available, _accelerate_version = _is_package_available("accelerate",
 _apex_available = _is_package_available("apex")
 _aqlm_available = _is_package_available("aqlm")
 _bitsandbytes_available = _is_package_available("bitsandbytes")
+_galore_torch_available = _is_package_available("galore_torch")
 # `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
 _bs4_available = importlib.util.find_spec("bs4") is not None
 _coloredlogs_available = _is_package_available("coloredlogs")
@@ -309,6 +310,10 @@ def is_torchvision_available():
     return _torchvision_available
 
 
+def is_galore_torch_available():
+    return _galore_torch_available
+
+
 def is_pyctcdecode_available():
     return _pyctcdecode_available
 
@@ -60,6 +60,7 @@ from transformers.testing_utils import (
     require_accelerate,
     require_bitsandbytes,
     require_deepspeed,
+    require_galore_torch,
     require_intel_extension_for_pytorch,
     require_optuna,
     require_peft,
@@ -84,7 +85,7 @@ from transformers.testing_utils import (
     slow,
     torch_device,
 )
-from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, HPSearchBackend
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, HPSearchBackend, check_target_module_exists
 from transformers.training_args import OptimizerNames
 from transformers.utils import (
     SAFE_WEIGHTS_INDEX_NAME,
@@ -114,6 +115,8 @@ if is_torch_available():
         GPT2Config,
         GPT2LMHeadModel,
         LineByLineTextDataset,
+        LlamaConfig,
+        LlamaForCausalLM,
         PreTrainedModel,
         Trainer,
         TrainerState,
@@ -146,6 +149,31 @@ class RegressionDataset:
         return result
 
 
+# Converting Bytes to Megabytes
+def bytes2megabytes(x):
+    return int(x / 2**20)
+
+
+# Copied from accelerate: https://github.com/huggingface/accelerate/blob/ee163b66fb7848892519e804688cb4ae981aacbe/src/accelerate/test_utils/scripts/external_deps/test_peak_memory_usage.py#L40C1-L73C68
+class TorchTracemalloc:
+    def __enter__(self):
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.reset_max_memory_allocated()  # reset the peak gauge to zero
+            self.begin = torch.cuda.memory_allocated()
+        return self
+
+    def __exit__(self, *exc):
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            self.end = torch.cuda.memory_allocated()
+            self.peak = torch.cuda.max_memory_allocated()
+        self.used = bytes2megabytes(self.end - self.begin)
+        self.peaked = bytes2megabytes(self.peak - self.begin)
+
+
 @dataclasses.dataclass
 class RegressionTrainingArguments(TrainingArguments):
     a: float = 0.0
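The memory tests later in this diff rebuild the absolute peak as `tracemalloc.peaked + bytes2megabytes(tracemalloc.begin)`. Because `bytes2megabytes` truncates, and divides by 2**20 (so the unit is strictly MiB), the sum of the two truncated terms can come out one unit lower than converting the peak directly. A small worked example with made-up byte counts (no GPU needed) shows the arithmetic; the values of `begin` and `peak` are assumptions chosen only for illustration.

```python
def bytes2megabytes(x):
    # Same truncating conversion as in the hunk above.
    return int(x / 2**20)


MIB = 2**20

# Made-up numbers standing in for torch.cuda.memory_allocated() readings.
begin = 100 * MIB + 512 * 1024  # 100.5 MiB allocated before training
peak = 680 * MIB                # 680 MiB at the high-water mark

peaked = bytes2megabytes(peak - begin)      # int(579.5) -> 579
reported = peaked + bytes2megabytes(begin)  # 579 + 100  -> 679
direct = bytes2megabytes(peak)              # 680

# The tests compare the reconstructed value against a fixed interval.
lower_bound_pm, upper_bound_pm = 650, 700
within_bounds = lower_bound_pm < reported < upper_bound_pm
```

The one-unit discrepancy is harmless for the 650 to 700 interval used in the tests, but it is worth keeping in mind if the bounds are ever tightened.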
@@ -1069,6 +1097,293 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         trainer.train()
         trainer.evaluate()
 
+    def test_galore_matched_modules(self):
+        regex_patterns = [r".*.attn.*", r".*.mlp.*"]
+
+        module_names = [
+            "model.transformer.h.0.ln_1",
+            "model.transformer.h.0.attn.q_proj",
+            "model.lm_head",
+            "model.transformer.h.0.mlp.up_proj",
+        ]
+        expected_values = [False, True, False, True]
+
+        for expected_value, module_name in zip(expected_values, module_names):
+            is_module_matched, is_regex = check_target_module_exists(regex_patterns, module_name, return_is_regex=True)
+            self.assertTrue(is_module_matched == expected_value)
+            if is_module_matched:
+                self.assertTrue(is_regex)
+
+        exact_patterns = ["q_proj", "up_proj"]
+
+        module_names = [
+            "model.transformer.h.0.ln_1",
+            "model.transformer.h.0.attn.q_proj",
+            "model.lm_head",
+            "model.transformer.h.0.mlp.up_proj",
+        ]
+        expected_values = [False, True, False, True]
+
+        for expected_value, module_name in zip(expected_values, module_names):
+            is_module_matched, is_regex = check_target_module_exists(exact_patterns, module_name, return_is_regex=True)
+            self.assertTrue(is_module_matched == expected_value)
+            if is_module_matched:
+                self.assertFalse(is_regex)
+
+        simple_regex = r".*.attn.*"
+
+        module_names = [
+            "model.transformer.h.0.ln_1",
+            "model.transformer.h.0.attn.q_proj",
+            "model.lm_head",
+            "model.transformer.h.0.mlp.up_proj",
+        ]
+        expected_values = [False, True, False, False]
+
+        for expected_value, module_name in zip(expected_values, module_names):
+            is_module_matched, is_regex = check_target_module_exists(simple_regex, module_name, return_is_regex=True)
+            self.assertTrue(is_module_matched == expected_value)
+            if is_module_matched:
+                self.assertTrue(is_regex)
+
+        simple_regex = "model.transformer.h.0.attn.q_proj"
+
+        module_names = [
+            "model.transformer.h.0.ln_1",
+            "model.transformer.h.0.attn.q_proj",
+            "model.lm_head",
+            "model.transformer.h.0.mlp.up_proj",
+        ]
+        expected_values = [False, True, False, False]
+
+        for expected_value, module_name in zip(expected_values, module_names):
+            is_module_matched, is_regex = check_target_module_exists(simple_regex, module_name, return_is_regex=True)
+            self.assertTrue(is_module_matched == expected_value)
+            if is_module_matched:
+                self.assertFalse(is_regex)
+
+        target_modules = ["attn", "mlp"]
+
+        module_names = [
+            "model.transformer.h.0.ln_1",
+            "model.transformer.h.0.attn.q_proj",
+            "model.lm_head",
+            "model.transformer.h.0.mlp.up_proj",
+        ]
+        expected_values = [False, True, False, True]
+
+        for expected_value, module_name in zip(expected_values, module_names):
+            is_module_matched, is_regex = check_target_module_exists(target_modules, module_name, return_is_regex=True)
+            self.assertTrue(is_module_matched == expected_value)
+            if is_module_matched:
+                self.assertFalse(is_regex)
+
+    @require_galore_torch
+    @require_torch_gpu
+    def test_galore(self):
+        config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
+        tiny_llama = LlamaForCausalLM(config)
+        x = torch.randint(0, 100, (128,))
+        train_dataset = RepeatDataset(x)
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            # Trainer without inf/nan filter
+            args = TrainingArguments(
+                tmpdir,
+                learning_rate=1e-9,
+                logging_steps=5,
+                optim="galore_adamw",
+                optim_target_modules=[r".*attn.*", r".*mlp.*"],
+            )
+            trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
+
+            # Check this works
+            _ = trainer.train()
+
+    @require_galore_torch
+    @require_torch_gpu
+    def test_galore_extra_args(self):
+        config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
+        tiny_llama = LlamaForCausalLM(config)
+        x = torch.randint(0, 100, (128,))
+        train_dataset = RepeatDataset(x)
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            # Trainer without inf/nan filter
+            args = TrainingArguments(
+                tmpdir,
+                learning_rate=1e-9,
+                logging_steps=5,
+                optim="galore_adamw",
+                optim_args="rank=64, update_proj_gap=100, scale=0.10",
+                optim_target_modules=[r".*attn.*", r".*mlp.*"],
+            )
+            trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
+
+            # Check this works
+            _ = trainer.train()
+
+    @require_galore_torch
+    @require_torch_gpu
+    def test_galore_layerwise(self):
+        config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
+        tiny_llama = LlamaForCausalLM(config)
+        x = torch.randint(0, 100, (128,))
+        train_dataset = RepeatDataset(x)
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            # Trainer without inf/nan filter
+            args = TrainingArguments(
+                tmpdir,
+                learning_rate=1e-9,
+                logging_steps=5,
+                optim="galore_adamw_layerwise",
+                optim_target_modules=[r".*attn.*", r".*mlp.*"],
+            )
+            trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
+
+            # Check this works
+            _ = trainer.train()
+
+    @require_galore_torch
+    @require_torch_gpu
+    def test_galore_layerwise_with_scheduler(self):
+        config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
+        tiny_llama = LlamaForCausalLM(config)
+        x = torch.randint(0, 100, (128,))
+        train_dataset = RepeatDataset(x)
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            # Trainer without inf/nan filter
+            args = TrainingArguments(
+                tmpdir,
+                learning_rate=1e-9,
+                logging_steps=5,
+                optim="galore_adamw_layerwise",
+                lr_scheduler_type="cosine",
+                optim_target_modules=[r".*attn.*", r".*mlp.*"],
+            )
+            trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
+
+            # Check this works
+            _ = trainer.train()
+
+    @require_galore_torch
+    @require_torch_gpu
+    def test_galore_adamw_8bit(self):
+        config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
+        tiny_llama = LlamaForCausalLM(config)
+        x = torch.randint(0, 100, (128,))
+        train_dataset = RepeatDataset(x)
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            # Trainer without inf/nan filter
+            args = TrainingArguments(
+                tmpdir,
+                learning_rate=1e-9,
+                logging_steps=5,
+                optim="galore_adamw_8bit",
+                optim_target_modules=[r".*attn.*", r".*mlp.*"],
+            )
+            trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
+
+            # Check this works
+            _ = trainer.train()
+
+    @require_galore_torch
+    @require_torch_gpu
+    def test_galore_adafactor(self):
+        # These are the intervals of the peak memory usage of training such a tiny model
+        # if the peak memory goes outside that range, then we know there might be a bug somewhere
+        upper_bound_pm = 700
+        lower_bound_pm = 650
+
+        config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
+        tiny_llama = LlamaForCausalLM(config)
+        x = torch.randint(0, 100, (128,))
+        train_dataset = RepeatDataset(x)
+
+        with tempfile.TemporaryDirectory() as tmpdir, TorchTracemalloc() as tracemalloc:
+            # Trainer without inf/nan filter
+            args = TrainingArguments(
+                tmpdir,
+                learning_rate=1e-9,
+                logging_steps=5,
+                optim="galore_adafactor",
+                optim_target_modules=[r".*attn.*", r".*mlp.*"],
+            )
+            trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
+
+            # Check this works
+            _ = trainer.train()
+
+        galore_peak_memory = tracemalloc.peaked + bytes2megabytes(tracemalloc.begin)
+
+        self.assertTrue(galore_peak_memory < upper_bound_pm)
+        self.assertTrue(lower_bound_pm < galore_peak_memory)
+
+    @require_galore_torch
+    @require_torch_gpu
+    def test_galore_adafactor_attention_only(self):
+        # These are the intervals of the peak memory usage of training such a tiny model
+        # if the peak memory goes outside that range, then we know there might be a bug somewhere
+        upper_bound_pm = 700
+        lower_bound_pm = 650
+
+        config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
+        tiny_llama = LlamaForCausalLM(config)
+        x = torch.randint(0, 100, (128,))
+        train_dataset = RepeatDataset(x)
+
+        with tempfile.TemporaryDirectory() as tmpdir, TorchTracemalloc() as tracemalloc:
+            # Trainer without inf/nan filter
+            args = TrainingArguments(
+                tmpdir,
+                learning_rate=1e-9,
+                logging_steps=5,
+                optim="galore_adafactor",
+                optim_target_modules=["q_proj", "k_proj", "v_proj"],
+            )
+            trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
+
+            # Check this works
+            _ = trainer.train()
+
+        galore_peak_memory = tracemalloc.peaked + bytes2megabytes(tracemalloc.begin)
+        self.assertTrue(galore_peak_memory < upper_bound_pm)
+        self.assertTrue(lower_bound_pm < galore_peak_memory)
+
+    @require_galore_torch
+    @require_torch_gpu
+    def test_galore_adafactor_all_linear(self):
+        # These are the intervals of the peak memory usage of training such a tiny model
+        # if the peak memory goes outside that range, then we know there might be a bug somewhere
+        upper_bound_pm = 700
+        lower_bound_pm = 650
+
+        config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
+        tiny_llama = LlamaForCausalLM(config)
+        x = torch.randint(0, 100, (128,))
+        train_dataset = RepeatDataset(x)
+
+        with tempfile.TemporaryDirectory() as tmpdir, TorchTracemalloc() as tracemalloc:
+            # Trainer without inf/nan filter
+            args = TrainingArguments(
+                tmpdir,
+                learning_rate=1e-9,
+                logging_steps=5,
+                optim="galore_adafactor",
+                optim_target_modules="all-linear",
+            )
+            trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
+
+            # Check this works
+            _ = trainer.train()
+
+        galore_peak_memory = tracemalloc.peaked + bytes2megabytes(tracemalloc.begin)
+        self.assertTrue(galore_peak_memory < upper_bound_pm)
+        self.assertTrue(lower_bound_pm < galore_peak_memory)
+
     @require_torch_multi_accelerator
     def test_data_is_not_parallelized_when_model_is_parallel(self):
         model = RepressionModel()
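The expectations in `test_galore_matched_modules` imply a lookup order for `optim_target_modules`: a single string is treated as a full-match pattern, while a list is tried first as exact names, then as substrings, and only then as regexes. The sketch below is one way to reproduce those expectations. The function name `match_target_module` is hypothetical, and the body is inferred from the test cases only; it should not be read as the implementation of `check_target_module_exists`.

```python
import re


def match_target_module(targets, key):
    """Hypothetical matcher: return (found, is_regex) for module name `key`."""
    if isinstance(targets, str):
        # A single string must match the whole module name.
        found = re.fullmatch(targets, key) is not None
        # If the string equals the name, count it as an exact hit, not a regex.
        return found, found and targets != key
    if key in targets:
        return True, False  # exact name listed
    if any(t in key for t in targets):
        return True, False  # plain substring such as "q_proj"
    if any(re.fullmatch(t, key) for t in targets):
        return True, True   # only a regex could have matched
    return False, False


names = [
    "model.transformer.h.0.ln_1",
    "model.transformer.h.0.attn.q_proj",
    "model.lm_head",
    "model.transformer.h.0.mlp.up_proj",
]
regex_hits = [match_target_module([r".*.attn.*", r".*.mlp.*"], n) for n in names]
exact_hits = [match_target_module(["q_proj", "up_proj"], n) for n in names]
```

Trying substrings before regexes is what lets plain entries such as `"attn"` report `is_regex=False` in the test, even though they would also be valid patterns.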