Optim: APOLLO optimizer integration (#36062)

* Added APOLLO optimizer integration

* fix comment

* Remove redundancy: Modularize low-rank optimizer construction

* Remove redundancy: Remove useless comment

* Fix comment: Add typing

* Fix comment: Rewrite apollo desc
This commit is contained in:
zhuHQ
2025-02-12 08:33:43 -06:00
committed by GitHub
parent 2440512723
commit 08c4959a23
7 changed files with 404 additions and 99 deletions

View File

@@ -443,6 +443,97 @@ trainer.train()
Note layerwise optimization is a bit experimental and does not support DDP (Distributed Data Parallel), thus you can run the training script only on a single GPU. Please see [this appropriate section](https://github.com/jiaweizzhao/GaLore?tab=readme-ov-file#train-7b-model-with-a-single-gpu-with-24gb-memory) for more details. Other features such as gradient clipping, DeepSpeed, etc might not be supported out of the box. Please [raise an issue on GitHub](https://github.com/huggingface/transformers/issues) if you encounter such issue. Note layerwise optimization is a bit experimental and does not support DDP (Distributed Data Parallel), thus you can run the training script only on a single GPU. Please see [this appropriate section](https://github.com/jiaweizzhao/GaLore?tab=readme-ov-file#train-7b-model-with-a-single-gpu-with-24gb-memory) for more details. Other features such as gradient clipping, DeepSpeed, etc might not be supported out of the box. Please [raise an issue on GitHub](https://github.com/huggingface/transformers/issues) if you encounter such issue.
### APOLLO
Approximated Gradient Scaling for Memory Efficient LLM Optimization (APOLLO) is a memory-efficient training strategy that allows full-parameter learning for both pre-training and fine-tuning, while maintaining AdamW-level performance with SGD-like memory efficiency.
* **Ultra-low rank efficiency** → Requires much lower rank than GaLore—even rank 1 (APOLLO-Mini) suffices.
* **No expensive SVD computations** → Unlike GaLore, APOLLO leverages random projection, avoiding training stalls.
You can read more about the method in the [original repository](https://github.com/zhuhanqing/APOLLO) or the [APOLLO: SGD-like Memory, AdamW-level Performance](https://arxiv.org/abs/2412.05270).
First, make sure to install APOLLO from its official repository:
```bash
pip install apollo-torch
```
Then, APOLLO optimizers can be used simply by setting `optim="apollo_adamw"` and specifying `optim_target_modules`.
`optim_target_modules` can be a list of strings, regex or full path corresponding to the target module names you want to adapt.
Currently, only Linear layers are considered to use the APOLLO optimizers, i.e., included in `optim_target_modules,` while the remaining models are still using AdamW.
You can also enable layer-wise APOLLO by appending "layerwise" to the optimizer name (optim="apollo_adamw_layerwise"), the same as layer-wise GaLore. This saves additional memory for gradient by performing weight updates layer by layer.
Below is an end-to-end example script (make sure to `pip install trl datasets`):
```python
import torch
import datasets
import trl
from transformers import TrainingArguments, AutoTokenizer, AutoModelForCausalLM
train_dataset = datasets.load_dataset('imdb', split='train')
args = TrainingArguments(
output_dir="./test-apollo",
max_steps=100,
per_device_train_batch_size=2,
optim="apollo_adamw",
optim_target_modules=[r".*.attn.*", r".*.mlp.*"]
)
model_id = "google/gemma-2b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True).to(0)
trainer = trl.SFTTrainer(
model=model,
args=args,
train_dataset=train_dataset,
dataset_text_field='text',
max_seq_length=512,
)
trainer.train()
```
You can further customize APOLLOs behavior by passing hyperparameters using `optim_args`.
| Parameter | Description |
|------------------|-------------|
| `rank` | Rank of the auxiliary sub-space used for gradient scaling. <br> **APOLLO (default=256)** → Works well for 1B and 7B models. <br> **APOLLO-Mini (default=1)** |
| `scale_type` | How scaling factors are applied. <br> **`channel`** → Per-channel scaling (used in APOLLO). <br> **`tensor`** → Per-tensor scaling (used in APOLLO-Mini). |
| `scale` | Adjusts gradient updates to stabilize training. <br> **APOLLO (default=1.0)** <br> **APOLLO-Mini (default=128)** |
| `update_proj_gap` | Steps before updating projection matrices. Default: **200**. |
| `proj` | Type of projection. Default: **`random`**. |
<Tip>
The `scale` parameter can be set to `n/r`, where `n` is the original space dimension and `r` is the low-rank space dimension.
Alternatively, you can achieve a similar effect by adjusting the learning rate, while keeping scale at its default value.
</Tip>
For example, you can enable APOLLO-Mini (rank=1 for extreme memory efficiency) by passing `optim_args`:
```python
args = TrainingArguments(
output_dir="./test-galore",
max_steps=100,
per_device_train_batch_size=2,
optim="apollo_adamw",
optim_target_modules=[r".*.attn.*", r".*.mlp.*"],
optim_args="proj=random,rank=1,scale=128.0,scale_type=tensor,update_proj_gap=200",
)
```
### LOMO optimizer ### LOMO optimizer
The LOMO optimizers have been introduced in [Full Parameter Fine-Tuning for Large Language Models with Limited Resources](https://hf.co/papers/2306.09782) and [AdaLomo: Low-memory Optimization with Adaptive Learning Rate](https://hf.co/papers/2310.10195). The LOMO optimizers have been introduced in [Full Parameter Fine-Tuning for Large Language Models with Limited Resources](https://hf.co/papers/2306.09782) and [AdaLomo: Low-memory Optimization with Adaptive Learning Rate](https://hf.co/papers/2310.10195).

View File

@@ -62,6 +62,7 @@ from .utils import (
GGUF_MIN_VERSION, GGUF_MIN_VERSION,
is_accelerate_available, is_accelerate_available,
is_apex_available, is_apex_available,
is_apollo_torch_available,
is_aqlm_available, is_aqlm_available,
is_auto_awq_available, is_auto_awq_available,
is_auto_gptq_available, is_auto_gptq_available,
@@ -404,6 +405,14 @@ def require_galore_torch(test_case):
return unittest.skipUnless(is_galore_torch_available(), "test requires GaLore")(test_case) return unittest.skipUnless(is_galore_torch_available(), "test requires GaLore")(test_case)
def require_apollo_torch(test_case):
"""
Decorator marking a test that requires GaLore. These tests are skipped when APOLLO isn't installed.
https://github.com/zhuhanqing/APOLLO
"""
return unittest.skipUnless(is_apollo_torch_available(), "test requires APOLLO")(test_case)
def require_lomo(test_case): def require_lomo(test_case):
""" """
Decorator marking a test that requires LOMO. These tests are skipped when LOMO-optim isn't installed. Decorator marking a test that requires LOMO. These tests are skipped when LOMO-optim isn't installed.

View File

@@ -151,6 +151,7 @@ from .utils import (
find_labels, find_labels,
is_accelerate_available, is_accelerate_available,
is_apex_available, is_apex_available,
is_apollo_torch_available,
is_bitsandbytes_available, is_bitsandbytes_available,
is_datasets_available, is_datasets_available,
is_galore_torch_available, is_galore_torch_available,
@@ -1315,6 +1316,103 @@ class Trainer:
"betas": (args.adam_beta1, args.adam_beta2), "betas": (args.adam_beta1, args.adam_beta2),
"eps": args.adam_epsilon, "eps": args.adam_epsilon,
} }
def setup_low_rank_optimizer(
optimizer_name: str,
optimizer_mapping: Dict[str, Any],
optim_kwargs: Dict[str, Any],
is_layerwise_supported: bool = True,
) -> Tuple[Any, Any]:
"""
Helper function to set up low-rank optimizers like GaLore and Apollo.
Args:
optimizer_name (str): Name of the optimizer.
optimizer_mapping (dict): Mapping of optimizer names to their classes.
optim_kwargs (dict): Keyword arguments for the optimizer.
is_layerwise_supported (bool): Whether layerwise optimization is supported.
Returns:
Tuple[Any, Any]: Optimizer class and updated optimizer kwargs.
"""
is_layerwise = optimizer_name.lower().endswith("layerwise")
if is_layerwise and args.parallel_mode == ParallelMode.DISTRIBUTED and is_layerwise_supported:
raise NotImplementedError(f"Layer-wise {optimizer_name} does not support DDP at this time")
optimizer_cls = optimizer_mapping[optimizer_name]
if args.optim_target_modules is None:
raise ValueError(f"You need to define `optim_target_modules` to use {optimizer_name} optimizers")
if not isinstance(args.optim_target_modules, (list, str)):
raise ValueError(
f"`optim_target_modules` must be a list of strings, a regex string, or 'all-linear'. Got: {args.optim_target_modules}"
)
if model is None:
raise ValueError(f"You need to pass a model to initialize {optimizer_name} optimizer.")
all_linear = (
isinstance(args.optim_target_modules, str)
and args.optim_target_modules.replace("_", "-") == "all-linear"
)
target_params = []
target_params_names = []
for module_name, module in model.named_modules():
target_module_exists, is_regex = check_target_module_exists(
args.optim_target_modules, module_name, return_is_regex=True
)
if not isinstance(module, nn.Linear):
if target_module_exists and not is_regex:
logger.warning(
f"{module_name} matched but ignored. {optimizer_name} only supports linear layers."
)
continue
if not target_module_exists and not all_linear:
continue
target_params.append(module.weight)
target_params_names.append(module_name + ".weight")
if len(target_params) == 0:
raise ValueError(f"No target modules found for {optimizer_name} ({args.optim_target_modules}).")
non_target_params = [p for n, p in model.named_parameters() if n not in target_params_names]
optim_kwargs.update(optim_args)
param_groups = [
{"params": non_target_params},
{"params": target_params, **optim_kwargs},
]
if is_layerwise:
if args.gradient_accumulation_steps != 1:
raise ValueError(f"Layerwise {optimizer_name} does not support gradient accumulation!")
optimizer_dict = {}
for param in non_target_params:
optimizer_dict[param] = optimizer_cls([{"params": [param]}], **optimizer_kwargs)
for param in target_params:
optimizer_dict[param] = optimizer_cls([{"params": [param], **optim_kwargs}], **optimizer_kwargs)
def optimizer_hook(param):
if param.grad is not None:
optimizer_dict[param].step()
optimizer_dict[param].zero_grad()
for param in model.parameters():
if param.requires_grad:
param.register_post_accumulate_grad_hook(optimizer_hook)
optimizer_cls = LayerWiseDummyOptimizer
optimizer_kwargs.update({"optimizer_dict": optimizer_dict})
optimizer_kwargs.update({"params": param_groups})
return optimizer_cls, optimizer_kwargs
if args.optim == OptimizerNames.ADAFACTOR: if args.optim == OptimizerNames.ADAFACTOR:
optimizer_cls = Adafactor optimizer_cls = Adafactor
optimizer_kwargs.update({"scale_parameter": False, "relative_step": False}) optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
@@ -1476,10 +1574,6 @@ class Trainer:
) )
from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit
is_layerwise = args.optim.lower().endswith("layerwise")
if is_layerwise and args.parallel_mode == ParallelMode.DISTRIBUTED:
raise NotImplementedError("Layer-wise GaLore does not support DDP at this time")
optimizer_mapping = { optimizer_mapping = {
OptimizerNames.GALORE_ADAMW: GaLoreAdamW, OptimizerNames.GALORE_ADAMW: GaLoreAdamW,
OptimizerNames.GALORE_ADAMW_8BIT: GaLoreAdamW8bit, OptimizerNames.GALORE_ADAMW_8BIT: GaLoreAdamW8bit,
@@ -1489,59 +1583,6 @@ class Trainer:
OptimizerNames.GALORE_ADAFACTOR_LAYERWISE: GaLoreAdafactor, OptimizerNames.GALORE_ADAFACTOR_LAYERWISE: GaLoreAdafactor,
} }
optimizer_cls = optimizer_mapping[args.optim]
if args.optim_target_modules is None:
raise ValueError(
"You need to define a `optim_target_modules` in order to properly use GaLore optimizers"
)
if not isinstance(args.optim_target_modules, (list, str)):
raise ValueError(
f"`optim_target_modules` has to be a list of strings, a string corresponding to a regex, or a specific module or 'all-linear', you passed {args.optim_target_modules}"
)
if model is None:
raise ValueError("You need to pass a model in order to correctly initialize a GaLore optimizer.")
logger.warning(
"Activated GaLoRE fine-tuning, depending on your model size and hardware, the training might take a while before starting. Please be patient !"
)
all_linear = (
isinstance(args.optim_target_modules, str)
and args.optim_target_modules.replace("_", "-") == "all-linear"
)
galore_params = []
galore_params_names = []
for module_name, module in model.named_modules():
target_module_exists, is_regex = check_target_module_exists(
args.optim_target_modules, module_name, return_is_regex=True
)
if not isinstance(module, nn.Linear):
# Warn in case we match but it's not a linear layer
if target_module_exists and not is_regex:
logger.warning(
f"{module_name} has been matched but ignored as GaLore only supports linear layers. Please double check your `optim_target_modules`!"
)
continue
if not target_module_exists and not all_linear:
continue
galore_params.append(module.weight)
galore_params_names.append(module_name + ".weight")
if len(galore_params) == 0:
raise ValueError(
f"None of the target modules were found! ({args.optim_target_modules}). Please make sure to pass a valid `target_modules`."
)
non_galore_params = [p for n, p in model.named_parameters() if n not in galore_params_names]
galore_optim_kwargs = { galore_optim_kwargs = {
"rank": int(optim_args.pop("rank", 128)), "rank": int(optim_args.pop("rank", 128)),
"update_proj_gap": int(optim_args.pop("update_proj_gap", 200)), "update_proj_gap": int(optim_args.pop("update_proj_gap", 200)),
@@ -1549,45 +1590,39 @@ class Trainer:
"proj_type": optim_args.pop("proj_type", "std"), "proj_type": optim_args.pop("proj_type", "std"),
} }
# The default args are from the official repository: https://github.com/jiaweizzhao/GaLore optimizer_cls, optimizer_kwargs = setup_low_rank_optimizer(
param_groups = [ args.optim, optimizer_mapping, galore_optim_kwargs
{"params": non_galore_params}, )
{"params": galore_params, **galore_optim_kwargs},
]
if is_layerwise:
# For layer-wise optimizers, the optimization step is done through post accumulation
# gradient hooks. The trick is to first attach these hooks to the model parameters then
# create a dummy optimizer that will perform no-ops in the Trainer.
# See the original implementation or the nice implementation from @hiyouga
# here: https://github.com/hiyouga/LLaMA-Factory/commit/8664262cde3919e10eaecbd66e8c5d356856362e#diff-ebe08ab14496dfb9e06075f0fdd36799ef6d1535cc4dd4715b74c4e3e06fe3ba
if args.gradient_accumulation_steps != 1:
raise ValueError("Layerwise GaLoRE optimizer do not support gradient accumulation !")
optimizer_dict = {}
for param in non_galore_params:
param_groups = [{"params": [param]}]
optimizer_dict[param] = optimizer_cls(param_groups, **optimizer_kwargs)
for param in galore_params:
param_groups = [{"params": [param], **galore_optim_kwargs}]
optimizer_dict[param] = optimizer_cls(param_groups, **optimizer_kwargs)
def optimizer_hook(param):
if param.grad is not None:
optimizer_dict[param].step()
optimizer_dict[param].zero_grad()
for param in model.parameters():
if param.requires_grad:
param.register_post_accumulate_grad_hook(optimizer_hook)
optimizer_cls = LayerWiseDummyOptimizer
optimizer_kwargs.update({"optimizer_dict": optimizer_dict})
optimizer_kwargs.update({"params": param_groups})
if args.optim == OptimizerNames.GALORE_ADAFACTOR: if args.optim == OptimizerNames.GALORE_ADAFACTOR:
optimizer_kwargs.update({"scale_parameter": False, "relative_step": False}) optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
elif args.optim in [
OptimizerNames.APOLLO_ADAMW,
OptimizerNames.APOLLO_ADAMW_LAYERWISE,
]:
if not is_apollo_torch_available():
raise ImportError(
"You need to install `apollo_torch` in order to use APOLLO optimizers"
" install it with `pip install git+https://github.com/zhuhanqing/APOLLO`"
)
from apollo_torch import APOLLOAdamW
optimizer_mapping = {
OptimizerNames.APOLLO_ADAMW: APOLLOAdamW,
OptimizerNames.APOLLO_ADAMW_LAYERWISE: APOLLOAdamW,
}
apollo_optim_kwargs = {
"rank": int(optim_args.pop("rank", 128)),
"proj": optim_args.pop("proj", "random"),
"scale_type": optim_args.pop("scale_type", "channel"),
"update_proj_gap": int(optim_args.pop("update_proj_gap", 200)),
"scale": float(optim_args.pop("scale", 1.0)),
"proj_type": optim_args.pop("proj_type", "std"),
}
optimizer_cls, optimizer_kwargs = setup_low_rank_optimizer(
args.optim, optimizer_mapping, apollo_optim_kwargs
)
elif args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]: elif args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
if not is_lomo_available(): if not is_lomo_available():
raise ImportError( raise ImportError(

View File

@@ -185,6 +185,8 @@ class OptimizerNames(ExplicitEnum):
SCHEDULE_FREE_RADAM = "schedule_free_radam" SCHEDULE_FREE_RADAM = "schedule_free_radam"
SCHEDULE_FREE_ADAMW = "schedule_free_adamw" SCHEDULE_FREE_ADAMW = "schedule_free_adamw"
SCHEDULE_FREE_SGD = "schedule_free_sgd" SCHEDULE_FREE_SGD = "schedule_free_sgd"
APOLLO_ADAMW = "apollo_adamw"
APOLLO_ADAMW_LAYERWISE = "apollo_adamw_layerwise"
# Sometimes users will pass in a `str` repr of a dict in the CLI # Sometimes users will pass in a `str` repr of a dict in the CLI
@@ -790,11 +792,10 @@ class TrainingArguments:
[original code](https://github.com/neelsjain/NEFTune). Support transformers `PreTrainedModel` and also [original code](https://github.com/neelsjain/NEFTune). Support transformers `PreTrainedModel` and also
`PeftModel` from peft. The original paper used values in the range [5.0, 15.0]. `PeftModel` from peft. The original paper used values in the range [5.0, 15.0].
optim_target_modules (`Union[str, List[str]]`, *optional*): optim_target_modules (`Union[str, List[str]]`, *optional*):
The target modules to optimize, i.e. the module names that you would like to train, right now this is used only for GaLore algorithm The target modules to optimize, i.e. the module names that you would like to train.
https://arxiv.org/abs/2403.03507 Currently used for the GaLore algorithm (https://arxiv.org/abs/2403.03507) and APOLLO algorithm (https://arxiv.org/abs/2412.05270).
See: https://github.com/jiaweizzhao/GaLore for more details. You need to make sure to pass a valid GaloRe See GaLore implementation (https://github.com/jiaweizzhao/GaLore) and APOLLO implementation (https://github.com/zhuhanqing/APOLLO) for more details.
optimizer, e.g. one of: "galore_adamw", "galore_adamw_8bit", "galore_adafactor" and make sure that the target modules are `nn.Linear` modules You need to make sure to pass a valid GaLore or APOLLO optimizer, e.g., one of: "apollo_adamw", "galore_adamw", "galore_adamw_8bit", "galore_adafactor" and make sure that the target modules are `nn.Linear` modules only.
only.
batch_eval_metrics (`Optional[bool]`, defaults to `False`): batch_eval_metrics (`Optional[bool]`, defaults to `False`):
If set to `True`, evaluation will call compute_metrics at the end of each batch to accumulate statistics If set to `True`, evaluation will call compute_metrics at the end of each batch to accumulate statistics

View File

@@ -117,6 +117,7 @@ from .import_utils import (
get_torch_version, get_torch_version,
is_accelerate_available, is_accelerate_available,
is_apex_available, is_apex_available,
is_apollo_torch_available,
is_aqlm_available, is_aqlm_available,
is_auto_awq_available, is_auto_awq_available,
is_auto_gptq_available, is_auto_gptq_available,

View File

@@ -99,6 +99,7 @@ VPTQ_MIN_VERSION = "0.0.4"
_accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True) _accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
_apex_available = _is_package_available("apex") _apex_available = _is_package_available("apex")
_apollo_torch_available = _is_package_available("apollo_torch")
_aqlm_available = _is_package_available("aqlm") _aqlm_available = _is_package_available("aqlm")
_vptq_available, _vptq_version = _is_package_available("vptq", return_version=True) _vptq_available, _vptq_version = _is_package_available("vptq", return_version=True)
_av_available = importlib.util.find_spec("av") is not None _av_available = importlib.util.find_spec("av") is not None
@@ -403,6 +404,10 @@ def is_galore_torch_available():
return _galore_torch_available return _galore_torch_available
def is_apollo_torch_available():
return _apollo_torch_available
def is_lomo_available(): def is_lomo_available():
return _lomo_available return _lomo_available

View File

@@ -66,6 +66,7 @@ from transformers.testing_utils import (
get_tests_dir, get_tests_dir,
is_staging_test, is_staging_test,
require_accelerate, require_accelerate,
require_apollo_torch,
require_bitsandbytes, require_bitsandbytes,
require_deepspeed, require_deepspeed,
require_galore_torch, require_galore_torch,
@@ -2259,6 +2260,168 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
# warm up steps << total steps # warm up steps << total steps
self.assertTrue(len(decreasing_lrs) > len(increasing_lrs)) self.assertTrue(len(decreasing_lrs) > len(increasing_lrs))
@require_apollo_torch
@require_torch_gpu
def test_apollo(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)
# Trainer without inf/nan filter
args = TrainingArguments(
self.get_auto_remove_tmp_dir(),
learning_rate=1e-9,
logging_steps=5,
optim="apollo_adamw",
optim_target_modules=[r".*attn.*", r".*mlp.*"],
)
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
# Check this works
_ = trainer.train()
@require_apollo_torch
@require_torch_gpu
def test_apollo_extra_args(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)
# Trainer without inf/nan filter
args = TrainingArguments(
self.get_auto_remove_tmp_dir(),
learning_rate=1e-9,
logging_steps=5,
optim="apollo_adamw",
optim_args="proj=random,scale_type=tensor,rank=1,update_proj_gap=100,scale=128.0",
optim_target_modules=[r".*attn.*", r".*mlp.*"],
)
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
# Check this works
_ = trainer.train()
@require_apollo_torch
@require_torch_gpu
def test_apollo_layerwise(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)
# Trainer without inf/nan filter
args = TrainingArguments(
self.get_auto_remove_tmp_dir(),
learning_rate=1e-9,
logging_steps=5,
optim="apollo_adamw_layerwise",
optim_target_modules=[r".*attn.*", r".*mlp.*"],
)
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
# Check this works
_ = trainer.train()
@require_apollo_torch
@require_torch_gpu
def test_apollo_layerwise_with_scheduler(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)
# Trainer without inf/nan filter
args = TrainingArguments(
self.get_auto_remove_tmp_dir(),
learning_rate=1e-9,
logging_steps=5,
optim="apollo_adamw_layerwise",
lr_scheduler_type="cosine",
optim_target_modules=[r".*attn.*", r".*mlp.*"],
)
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
# Check this works
_ = trainer.train()
@require_apollo_torch
@require_torch_gpu
def test_apollo_lr_display_without_scheduler(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)
learning_rate = 1e-9
num_steps = 10
# Trainer without inf/nan filter
args = TrainingArguments(
self.get_auto_remove_tmp_dir(),
learning_rate=learning_rate,
logging_steps=5,
optim="apollo_adamw",
optim_target_modules=[r".*attn.*", r".*mlp.*"],
)
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
trainer.create_optimizer_and_scheduler(num_training_steps=num_steps)
# reflects displayed lr in trainer
self.assertEqual(trainer.get_learning_rates(), [learning_rate, learning_rate])
@require_apollo_torch
@require_torch_gpu
def test_apollo_lr_display_with_scheduler(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)
learning_rate = 2e-4
num_train_epochs = 10
num_warmup_steps = 5
# Trainer without inf/nan filter
args = TrainingArguments(
self.get_auto_remove_tmp_dir(),
num_train_epochs=num_train_epochs,
learning_rate=learning_rate,
warmup_steps=num_warmup_steps,
lr_scheduler_type="cosine",
logging_steps=1,
optim="apollo_adamw",
optim_target_modules=[r".*attn.*", r".*mlp.*"],
)
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
# creating log history of trainer, results don't matter
trainer.train()
logs = trainer.state.log_history[1:][:-1]
# reach given learning rate peak and end with 0 lr
self.assertTrue(logs[num_warmup_steps - 2]["learning_rate"] == learning_rate)
self.assertTrue(logs[-1]["learning_rate"] == 0)
# increasing and decreasing pattern of lrs
increasing_lrs = [
logs[i]["learning_rate"] < logs[i + 1]["learning_rate"]
for i in range(len(logs))
if i < num_warmup_steps - 2
]
decreasing_lrs = [
logs[i]["learning_rate"] > logs[i + 1]["learning_rate"]
for i in range(len(logs) - 1)
if i >= num_warmup_steps - 2
]
self.assertTrue(all(increasing_lrs))
self.assertTrue(all(decreasing_lrs))
# warm up steps << total steps
self.assertTrue(len(decreasing_lrs) > len(increasing_lrs))
@require_torch_multi_accelerator @require_torch_multi_accelerator
def test_data_is_not_parallelized_when_model_is_parallel(self): def test_data_is_not_parallelized_when_model_is_parallel(self):
model = RegressionModel() model = RegressionModel()