Optim: APOLLO optimizer integration (#36062)
* Added APOLLO optimizer integration * fix comment * Remove redundancy: Modularize low-rank optimizer construction * Remove redundancy: Remove useless comment * Fix comment: Add typing * Fix comment: Rewrite apollo desc
This commit is contained in:
@@ -443,6 +443,97 @@ trainer.train()
|
||||
|
||||
Note layerwise optimization is a bit experimental and does not support DDP (Distributed Data Parallel), thus you can run the training script only on a single GPU. Please see [this appropriate section](https://github.com/jiaweizzhao/GaLore?tab=readme-ov-file#train-7b-model-with-a-single-gpu-with-24gb-memory) for more details. Other features such as gradient clipping, DeepSpeed, etc might not be supported out of the box. Please [raise an issue on GitHub](https://github.com/huggingface/transformers/issues) if you encounter such issue.
|
||||
|
||||
### APOLLO
|
||||
|
||||
Approximated Gradient Scaling for Memory Efficient LLM Optimization (APOLLO) is a memory-efficient training strategy that allows full-parameter learning for both pre-training and fine-tuning, while maintaining AdamW-level performance with SGD-like memory efficiency.
|
||||
|
||||
* **Ultra-low rank efficiency** → Requires much lower rank than GaLore—even rank 1 (APOLLO-Mini) suffices.
|
||||
* **No expensive SVD computations** → Unlike GaLore, APOLLO leverages random projection, avoiding training stalls.
|
||||
|
||||
You can read more about the method in the [original repository](https://github.com/zhuhanqing/APOLLO) or the [APOLLO: SGD-like Memory, AdamW-level Performance](https://arxiv.org/abs/2412.05270).
|
||||
|
||||
First, make sure to install APOLLO from its official repository:
|
||||
|
||||
```bash
|
||||
pip install apollo-torch
|
||||
```
|
||||
|
||||
Then, APOLLO optimizers can be used simply by setting `optim="apollo_adamw"` and specifying `optim_target_modules`.
|
||||
`optim_target_modules` can be a list of strings, regex or full path corresponding to the target module names you want to adapt.
|
||||
Currently, only Linear layers are considered to use the APOLLO optimizers, i.e., included in `optim_target_modules,` while the remaining models are still using AdamW.
|
||||
|
||||
|
||||
You can also enable layer-wise APOLLO by appending "layerwise" to the optimizer name (optim="apollo_adamw_layerwise"), the same as layer-wise GaLore. This saves additional memory for gradient by performing weight updates layer by layer.
|
||||
|
||||
Below is an end-to-end example script (make sure to `pip install trl datasets`):
|
||||
|
||||
```python
|
||||
import torch
|
||||
import datasets
|
||||
import trl
|
||||
|
||||
from transformers import TrainingArguments, AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
train_dataset = datasets.load_dataset('imdb', split='train')
|
||||
|
||||
args = TrainingArguments(
|
||||
output_dir="./test-apollo",
|
||||
max_steps=100,
|
||||
per_device_train_batch_size=2,
|
||||
optim="apollo_adamw",
|
||||
optim_target_modules=[r".*.attn.*", r".*.mlp.*"]
|
||||
)
|
||||
|
||||
model_id = "google/gemma-2b"
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True).to(0)
|
||||
|
||||
trainer = trl.SFTTrainer(
|
||||
model=model,
|
||||
args=args,
|
||||
train_dataset=train_dataset,
|
||||
dataset_text_field='text',
|
||||
max_seq_length=512,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
|
||||
You can further customize APOLLO’s behavior by passing hyperparameters using `optim_args`.
|
||||
|
||||
| Parameter | Description |
|
||||
|------------------|-------------|
|
||||
| `rank` | Rank of the auxiliary sub-space used for gradient scaling. <br> **APOLLO (default=256)** → Works well for 1B and 7B models. <br> **APOLLO-Mini (default=1)** |
|
||||
| `scale_type` | How scaling factors are applied. <br> **`channel`** → Per-channel scaling (used in APOLLO). <br> **`tensor`** → Per-tensor scaling (used in APOLLO-Mini). |
|
||||
| `scale` | Adjusts gradient updates to stabilize training. <br> **APOLLO (default=1.0)** <br> **APOLLO-Mini (default=128)** |
|
||||
| `update_proj_gap` | Steps before updating projection matrices. Default: **200**. |
|
||||
| `proj` | Type of projection. Default: **`random`**. |
|
||||
|
||||
|
||||
<Tip>
|
||||
|
||||
The `scale` parameter can be set to `n/r`, where `n` is the original space dimension and `r` is the low-rank space dimension.
|
||||
Alternatively, you can achieve a similar effect by adjusting the learning rate, while keeping scale at its default value.
|
||||
|
||||
</Tip>
|
||||
|
||||
For example, you can enable APOLLO-Mini (rank=1 for extreme memory efficiency) by passing `optim_args`:
|
||||
|
||||
```python
|
||||
|
||||
args = TrainingArguments(
|
||||
output_dir="./test-galore",
|
||||
max_steps=100,
|
||||
per_device_train_batch_size=2,
|
||||
optim="apollo_adamw",
|
||||
optim_target_modules=[r".*.attn.*", r".*.mlp.*"],
|
||||
optim_args="proj=random,rank=1,scale=128.0,scale_type=tensor,update_proj_gap=200",
|
||||
|
||||
)
|
||||
```
|
||||
|
||||
### LOMO optimizer
|
||||
|
||||
The LOMO optimizers have been introduced in [Full Parameter Fine-Tuning for Large Language Models with Limited Resources](https://hf.co/papers/2306.09782) and [AdaLomo: Low-memory Optimization with Adaptive Learning Rate](https://hf.co/papers/2310.10195).
|
||||
|
||||
@@ -62,6 +62,7 @@ from .utils import (
|
||||
GGUF_MIN_VERSION,
|
||||
is_accelerate_available,
|
||||
is_apex_available,
|
||||
is_apollo_torch_available,
|
||||
is_aqlm_available,
|
||||
is_auto_awq_available,
|
||||
is_auto_gptq_available,
|
||||
@@ -404,6 +405,14 @@ def require_galore_torch(test_case):
|
||||
return unittest.skipUnless(is_galore_torch_available(), "test requires GaLore")(test_case)
|
||||
|
||||
|
||||
def require_apollo_torch(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires GaLore. These tests are skipped when APOLLO isn't installed.
|
||||
https://github.com/zhuhanqing/APOLLO
|
||||
"""
|
||||
return unittest.skipUnless(is_apollo_torch_available(), "test requires APOLLO")(test_case)
|
||||
|
||||
|
||||
def require_lomo(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires LOMO. These tests are skipped when LOMO-optim isn't installed.
|
||||
|
||||
@@ -151,6 +151,7 @@ from .utils import (
|
||||
find_labels,
|
||||
is_accelerate_available,
|
||||
is_apex_available,
|
||||
is_apollo_torch_available,
|
||||
is_bitsandbytes_available,
|
||||
is_datasets_available,
|
||||
is_galore_torch_available,
|
||||
@@ -1315,6 +1316,103 @@ class Trainer:
|
||||
"betas": (args.adam_beta1, args.adam_beta2),
|
||||
"eps": args.adam_epsilon,
|
||||
}
|
||||
|
||||
def setup_low_rank_optimizer(
|
||||
optimizer_name: str,
|
||||
optimizer_mapping: Dict[str, Any],
|
||||
optim_kwargs: Dict[str, Any],
|
||||
is_layerwise_supported: bool = True,
|
||||
) -> Tuple[Any, Any]:
|
||||
"""
|
||||
Helper function to set up low-rank optimizers like GaLore and Apollo.
|
||||
|
||||
Args:
|
||||
optimizer_name (str): Name of the optimizer.
|
||||
optimizer_mapping (dict): Mapping of optimizer names to their classes.
|
||||
optim_kwargs (dict): Keyword arguments for the optimizer.
|
||||
is_layerwise_supported (bool): Whether layerwise optimization is supported.
|
||||
|
||||
Returns:
|
||||
Tuple[Any, Any]: Optimizer class and updated optimizer kwargs.
|
||||
"""
|
||||
is_layerwise = optimizer_name.lower().endswith("layerwise")
|
||||
if is_layerwise and args.parallel_mode == ParallelMode.DISTRIBUTED and is_layerwise_supported:
|
||||
raise NotImplementedError(f"Layer-wise {optimizer_name} does not support DDP at this time")
|
||||
|
||||
optimizer_cls = optimizer_mapping[optimizer_name]
|
||||
|
||||
if args.optim_target_modules is None:
|
||||
raise ValueError(f"You need to define `optim_target_modules` to use {optimizer_name} optimizers")
|
||||
|
||||
if not isinstance(args.optim_target_modules, (list, str)):
|
||||
raise ValueError(
|
||||
f"`optim_target_modules` must be a list of strings, a regex string, or 'all-linear'. Got: {args.optim_target_modules}"
|
||||
)
|
||||
|
||||
if model is None:
|
||||
raise ValueError(f"You need to pass a model to initialize {optimizer_name} optimizer.")
|
||||
|
||||
all_linear = (
|
||||
isinstance(args.optim_target_modules, str)
|
||||
and args.optim_target_modules.replace("_", "-") == "all-linear"
|
||||
)
|
||||
|
||||
target_params = []
|
||||
target_params_names = []
|
||||
for module_name, module in model.named_modules():
|
||||
target_module_exists, is_regex = check_target_module_exists(
|
||||
args.optim_target_modules, module_name, return_is_regex=True
|
||||
)
|
||||
|
||||
if not isinstance(module, nn.Linear):
|
||||
if target_module_exists and not is_regex:
|
||||
logger.warning(
|
||||
f"{module_name} matched but ignored. {optimizer_name} only supports linear layers."
|
||||
)
|
||||
continue
|
||||
|
||||
if not target_module_exists and not all_linear:
|
||||
continue
|
||||
|
||||
target_params.append(module.weight)
|
||||
target_params_names.append(module_name + ".weight")
|
||||
|
||||
if len(target_params) == 0:
|
||||
raise ValueError(f"No target modules found for {optimizer_name} ({args.optim_target_modules}).")
|
||||
|
||||
non_target_params = [p for n, p in model.named_parameters() if n not in target_params_names]
|
||||
optim_kwargs.update(optim_args)
|
||||
|
||||
param_groups = [
|
||||
{"params": non_target_params},
|
||||
{"params": target_params, **optim_kwargs},
|
||||
]
|
||||
|
||||
if is_layerwise:
|
||||
if args.gradient_accumulation_steps != 1:
|
||||
raise ValueError(f"Layerwise {optimizer_name} does not support gradient accumulation!")
|
||||
|
||||
optimizer_dict = {}
|
||||
for param in non_target_params:
|
||||
optimizer_dict[param] = optimizer_cls([{"params": [param]}], **optimizer_kwargs)
|
||||
for param in target_params:
|
||||
optimizer_dict[param] = optimizer_cls([{"params": [param], **optim_kwargs}], **optimizer_kwargs)
|
||||
|
||||
def optimizer_hook(param):
|
||||
if param.grad is not None:
|
||||
optimizer_dict[param].step()
|
||||
optimizer_dict[param].zero_grad()
|
||||
|
||||
for param in model.parameters():
|
||||
if param.requires_grad:
|
||||
param.register_post_accumulate_grad_hook(optimizer_hook)
|
||||
|
||||
optimizer_cls = LayerWiseDummyOptimizer
|
||||
optimizer_kwargs.update({"optimizer_dict": optimizer_dict})
|
||||
|
||||
optimizer_kwargs.update({"params": param_groups})
|
||||
return optimizer_cls, optimizer_kwargs
|
||||
|
||||
if args.optim == OptimizerNames.ADAFACTOR:
|
||||
optimizer_cls = Adafactor
|
||||
optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
|
||||
@@ -1476,10 +1574,6 @@ class Trainer:
|
||||
)
|
||||
from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit
|
||||
|
||||
is_layerwise = args.optim.lower().endswith("layerwise")
|
||||
if is_layerwise and args.parallel_mode == ParallelMode.DISTRIBUTED:
|
||||
raise NotImplementedError("Layer-wise GaLore does not support DDP at this time")
|
||||
|
||||
optimizer_mapping = {
|
||||
OptimizerNames.GALORE_ADAMW: GaLoreAdamW,
|
||||
OptimizerNames.GALORE_ADAMW_8BIT: GaLoreAdamW8bit,
|
||||
@@ -1489,59 +1583,6 @@ class Trainer:
|
||||
OptimizerNames.GALORE_ADAFACTOR_LAYERWISE: GaLoreAdafactor,
|
||||
}
|
||||
|
||||
optimizer_cls = optimizer_mapping[args.optim]
|
||||
|
||||
if args.optim_target_modules is None:
|
||||
raise ValueError(
|
||||
"You need to define a `optim_target_modules` in order to properly use GaLore optimizers"
|
||||
)
|
||||
|
||||
if not isinstance(args.optim_target_modules, (list, str)):
|
||||
raise ValueError(
|
||||
f"`optim_target_modules` has to be a list of strings, a string corresponding to a regex, or a specific module or 'all-linear', you passed {args.optim_target_modules}"
|
||||
)
|
||||
|
||||
if model is None:
|
||||
raise ValueError("You need to pass a model in order to correctly initialize a GaLore optimizer.")
|
||||
|
||||
logger.warning(
|
||||
"Activated GaLoRE fine-tuning, depending on your model size and hardware, the training might take a while before starting. Please be patient !"
|
||||
)
|
||||
|
||||
all_linear = (
|
||||
isinstance(args.optim_target_modules, str)
|
||||
and args.optim_target_modules.replace("_", "-") == "all-linear"
|
||||
)
|
||||
|
||||
galore_params = []
|
||||
galore_params_names = []
|
||||
for module_name, module in model.named_modules():
|
||||
target_module_exists, is_regex = check_target_module_exists(
|
||||
args.optim_target_modules, module_name, return_is_regex=True
|
||||
)
|
||||
|
||||
if not isinstance(module, nn.Linear):
|
||||
# Warn in case we match but it's not a linear layer
|
||||
if target_module_exists and not is_regex:
|
||||
logger.warning(
|
||||
f"{module_name} has been matched but ignored as GaLore only supports linear layers. Please double check your `optim_target_modules`!"
|
||||
)
|
||||
|
||||
continue
|
||||
|
||||
if not target_module_exists and not all_linear:
|
||||
continue
|
||||
|
||||
galore_params.append(module.weight)
|
||||
galore_params_names.append(module_name + ".weight")
|
||||
|
||||
if len(galore_params) == 0:
|
||||
raise ValueError(
|
||||
f"None of the target modules were found! ({args.optim_target_modules}). Please make sure to pass a valid `target_modules`."
|
||||
)
|
||||
|
||||
non_galore_params = [p for n, p in model.named_parameters() if n not in galore_params_names]
|
||||
|
||||
galore_optim_kwargs = {
|
||||
"rank": int(optim_args.pop("rank", 128)),
|
||||
"update_proj_gap": int(optim_args.pop("update_proj_gap", 200)),
|
||||
@@ -1549,45 +1590,39 @@ class Trainer:
|
||||
"proj_type": optim_args.pop("proj_type", "std"),
|
||||
}
|
||||
|
||||
# The default args are from the official repository: https://github.com/jiaweizzhao/GaLore
|
||||
param_groups = [
|
||||
{"params": non_galore_params},
|
||||
{"params": galore_params, **galore_optim_kwargs},
|
||||
]
|
||||
|
||||
if is_layerwise:
|
||||
# For layer-wise optimizers, the optimization step is done through post accumulation
|
||||
# gradient hooks. The trick is to first attach these hooks to the model parameters then
|
||||
# create a dummy optimizer that will perform no-ops in the Trainer.
|
||||
# See the original implementation or the nice implementation from @hiyouga
|
||||
# here: https://github.com/hiyouga/LLaMA-Factory/commit/8664262cde3919e10eaecbd66e8c5d356856362e#diff-ebe08ab14496dfb9e06075f0fdd36799ef6d1535cc4dd4715b74c4e3e06fe3ba
|
||||
if args.gradient_accumulation_steps != 1:
|
||||
raise ValueError("Layerwise GaLoRE optimizer do not support gradient accumulation !")
|
||||
|
||||
optimizer_dict = {}
|
||||
for param in non_galore_params:
|
||||
param_groups = [{"params": [param]}]
|
||||
optimizer_dict[param] = optimizer_cls(param_groups, **optimizer_kwargs)
|
||||
for param in galore_params:
|
||||
param_groups = [{"params": [param], **galore_optim_kwargs}]
|
||||
optimizer_dict[param] = optimizer_cls(param_groups, **optimizer_kwargs)
|
||||
|
||||
def optimizer_hook(param):
|
||||
if param.grad is not None:
|
||||
optimizer_dict[param].step()
|
||||
optimizer_dict[param].zero_grad()
|
||||
|
||||
for param in model.parameters():
|
||||
if param.requires_grad:
|
||||
param.register_post_accumulate_grad_hook(optimizer_hook)
|
||||
|
||||
optimizer_cls = LayerWiseDummyOptimizer
|
||||
optimizer_kwargs.update({"optimizer_dict": optimizer_dict})
|
||||
|
||||
optimizer_kwargs.update({"params": param_groups})
|
||||
|
||||
optimizer_cls, optimizer_kwargs = setup_low_rank_optimizer(
|
||||
args.optim, optimizer_mapping, galore_optim_kwargs
|
||||
)
|
||||
if args.optim == OptimizerNames.GALORE_ADAFACTOR:
|
||||
optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
|
||||
elif args.optim in [
|
||||
OptimizerNames.APOLLO_ADAMW,
|
||||
OptimizerNames.APOLLO_ADAMW_LAYERWISE,
|
||||
]:
|
||||
if not is_apollo_torch_available():
|
||||
raise ImportError(
|
||||
"You need to install `apollo_torch` in order to use APOLLO optimizers"
|
||||
" install it with `pip install git+https://github.com/zhuhanqing/APOLLO`"
|
||||
)
|
||||
from apollo_torch import APOLLOAdamW
|
||||
|
||||
optimizer_mapping = {
|
||||
OptimizerNames.APOLLO_ADAMW: APOLLOAdamW,
|
||||
OptimizerNames.APOLLO_ADAMW_LAYERWISE: APOLLOAdamW,
|
||||
}
|
||||
|
||||
apollo_optim_kwargs = {
|
||||
"rank": int(optim_args.pop("rank", 128)),
|
||||
"proj": optim_args.pop("proj", "random"),
|
||||
"scale_type": optim_args.pop("scale_type", "channel"),
|
||||
"update_proj_gap": int(optim_args.pop("update_proj_gap", 200)),
|
||||
"scale": float(optim_args.pop("scale", 1.0)),
|
||||
"proj_type": optim_args.pop("proj_type", "std"),
|
||||
}
|
||||
|
||||
optimizer_cls, optimizer_kwargs = setup_low_rank_optimizer(
|
||||
args.optim, optimizer_mapping, apollo_optim_kwargs
|
||||
)
|
||||
elif args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
|
||||
if not is_lomo_available():
|
||||
raise ImportError(
|
||||
|
||||
@@ -185,6 +185,8 @@ class OptimizerNames(ExplicitEnum):
|
||||
SCHEDULE_FREE_RADAM = "schedule_free_radam"
|
||||
SCHEDULE_FREE_ADAMW = "schedule_free_adamw"
|
||||
SCHEDULE_FREE_SGD = "schedule_free_sgd"
|
||||
APOLLO_ADAMW = "apollo_adamw"
|
||||
APOLLO_ADAMW_LAYERWISE = "apollo_adamw_layerwise"
|
||||
|
||||
|
||||
# Sometimes users will pass in a `str` repr of a dict in the CLI
|
||||
@@ -790,11 +792,10 @@ class TrainingArguments:
|
||||
[original code](https://github.com/neelsjain/NEFTune). Support transformers `PreTrainedModel` and also
|
||||
`PeftModel` from peft. The original paper used values in the range [5.0, 15.0].
|
||||
optim_target_modules (`Union[str, List[str]]`, *optional*):
|
||||
The target modules to optimize, i.e. the module names that you would like to train, right now this is used only for GaLore algorithm
|
||||
https://arxiv.org/abs/2403.03507
|
||||
See: https://github.com/jiaweizzhao/GaLore for more details. You need to make sure to pass a valid GaloRe
|
||||
optimizer, e.g. one of: "galore_adamw", "galore_adamw_8bit", "galore_adafactor" and make sure that the target modules are `nn.Linear` modules
|
||||
only.
|
||||
The target modules to optimize, i.e. the module names that you would like to train.
|
||||
Currently used for the GaLore algorithm (https://arxiv.org/abs/2403.03507) and APOLLO algorithm (https://arxiv.org/abs/2412.05270).
|
||||
See GaLore implementation (https://github.com/jiaweizzhao/GaLore) and APOLLO implementation (https://github.com/zhuhanqing/APOLLO) for more details.
|
||||
You need to make sure to pass a valid GaLore or APOLLO optimizer, e.g., one of: "apollo_adamw", "galore_adamw", "galore_adamw_8bit", "galore_adafactor" and make sure that the target modules are `nn.Linear` modules only.
|
||||
|
||||
batch_eval_metrics (`Optional[bool]`, defaults to `False`):
|
||||
If set to `True`, evaluation will call compute_metrics at the end of each batch to accumulate statistics
|
||||
|
||||
@@ -117,6 +117,7 @@ from .import_utils import (
|
||||
get_torch_version,
|
||||
is_accelerate_available,
|
||||
is_apex_available,
|
||||
is_apollo_torch_available,
|
||||
is_aqlm_available,
|
||||
is_auto_awq_available,
|
||||
is_auto_gptq_available,
|
||||
|
||||
@@ -99,6 +99,7 @@ VPTQ_MIN_VERSION = "0.0.4"
|
||||
|
||||
_accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
|
||||
_apex_available = _is_package_available("apex")
|
||||
_apollo_torch_available = _is_package_available("apollo_torch")
|
||||
_aqlm_available = _is_package_available("aqlm")
|
||||
_vptq_available, _vptq_version = _is_package_available("vptq", return_version=True)
|
||||
_av_available = importlib.util.find_spec("av") is not None
|
||||
@@ -403,6 +404,10 @@ def is_galore_torch_available():
|
||||
return _galore_torch_available
|
||||
|
||||
|
||||
def is_apollo_torch_available():
|
||||
return _apollo_torch_available
|
||||
|
||||
|
||||
def is_lomo_available():
|
||||
return _lomo_available
|
||||
|
||||
|
||||
@@ -66,6 +66,7 @@ from transformers.testing_utils import (
|
||||
get_tests_dir,
|
||||
is_staging_test,
|
||||
require_accelerate,
|
||||
require_apollo_torch,
|
||||
require_bitsandbytes,
|
||||
require_deepspeed,
|
||||
require_galore_torch,
|
||||
@@ -2259,6 +2260,168 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
# warm up steps << total steps
|
||||
self.assertTrue(len(decreasing_lrs) > len(increasing_lrs))
|
||||
|
||||
@require_apollo_torch
|
||||
@require_torch_gpu
|
||||
def test_apollo(self):
|
||||
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||
tiny_llama = LlamaForCausalLM(config)
|
||||
x = torch.randint(0, 100, (128,))
|
||||
train_dataset = RepeatDataset(x)
|
||||
|
||||
# Trainer without inf/nan filter
|
||||
args = TrainingArguments(
|
||||
self.get_auto_remove_tmp_dir(),
|
||||
learning_rate=1e-9,
|
||||
logging_steps=5,
|
||||
optim="apollo_adamw",
|
||||
optim_target_modules=[r".*attn.*", r".*mlp.*"],
|
||||
)
|
||||
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
|
||||
|
||||
# Check this works
|
||||
_ = trainer.train()
|
||||
|
||||
@require_apollo_torch
|
||||
@require_torch_gpu
|
||||
def test_apollo_extra_args(self):
|
||||
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||
tiny_llama = LlamaForCausalLM(config)
|
||||
x = torch.randint(0, 100, (128,))
|
||||
train_dataset = RepeatDataset(x)
|
||||
|
||||
# Trainer without inf/nan filter
|
||||
args = TrainingArguments(
|
||||
self.get_auto_remove_tmp_dir(),
|
||||
learning_rate=1e-9,
|
||||
logging_steps=5,
|
||||
optim="apollo_adamw",
|
||||
optim_args="proj=random,scale_type=tensor,rank=1,update_proj_gap=100,scale=128.0",
|
||||
optim_target_modules=[r".*attn.*", r".*mlp.*"],
|
||||
)
|
||||
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
|
||||
|
||||
# Check this works
|
||||
_ = trainer.train()
|
||||
|
||||
@require_apollo_torch
|
||||
@require_torch_gpu
|
||||
def test_apollo_layerwise(self):
|
||||
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||
tiny_llama = LlamaForCausalLM(config)
|
||||
x = torch.randint(0, 100, (128,))
|
||||
train_dataset = RepeatDataset(x)
|
||||
|
||||
# Trainer without inf/nan filter
|
||||
args = TrainingArguments(
|
||||
self.get_auto_remove_tmp_dir(),
|
||||
learning_rate=1e-9,
|
||||
logging_steps=5,
|
||||
optim="apollo_adamw_layerwise",
|
||||
optim_target_modules=[r".*attn.*", r".*mlp.*"],
|
||||
)
|
||||
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
|
||||
|
||||
# Check this works
|
||||
_ = trainer.train()
|
||||
|
||||
@require_apollo_torch
|
||||
@require_torch_gpu
|
||||
def test_apollo_layerwise_with_scheduler(self):
|
||||
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||
tiny_llama = LlamaForCausalLM(config)
|
||||
x = torch.randint(0, 100, (128,))
|
||||
train_dataset = RepeatDataset(x)
|
||||
|
||||
# Trainer without inf/nan filter
|
||||
args = TrainingArguments(
|
||||
self.get_auto_remove_tmp_dir(),
|
||||
learning_rate=1e-9,
|
||||
logging_steps=5,
|
||||
optim="apollo_adamw_layerwise",
|
||||
lr_scheduler_type="cosine",
|
||||
optim_target_modules=[r".*attn.*", r".*mlp.*"],
|
||||
)
|
||||
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
|
||||
|
||||
# Check this works
|
||||
_ = trainer.train()
|
||||
|
||||
@require_apollo_torch
|
||||
@require_torch_gpu
|
||||
def test_apollo_lr_display_without_scheduler(self):
|
||||
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||
tiny_llama = LlamaForCausalLM(config)
|
||||
x = torch.randint(0, 100, (128,))
|
||||
train_dataset = RepeatDataset(x)
|
||||
|
||||
learning_rate = 1e-9
|
||||
num_steps = 10
|
||||
|
||||
# Trainer without inf/nan filter
|
||||
args = TrainingArguments(
|
||||
self.get_auto_remove_tmp_dir(),
|
||||
learning_rate=learning_rate,
|
||||
logging_steps=5,
|
||||
optim="apollo_adamw",
|
||||
optim_target_modules=[r".*attn.*", r".*mlp.*"],
|
||||
)
|
||||
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
|
||||
trainer.create_optimizer_and_scheduler(num_training_steps=num_steps)
|
||||
|
||||
# reflects displayed lr in trainer
|
||||
self.assertEqual(trainer.get_learning_rates(), [learning_rate, learning_rate])
|
||||
|
||||
@require_apollo_torch
|
||||
@require_torch_gpu
|
||||
def test_apollo_lr_display_with_scheduler(self):
|
||||
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||
tiny_llama = LlamaForCausalLM(config)
|
||||
x = torch.randint(0, 100, (128,))
|
||||
train_dataset = RepeatDataset(x)
|
||||
|
||||
learning_rate = 2e-4
|
||||
num_train_epochs = 10
|
||||
num_warmup_steps = 5
|
||||
|
||||
# Trainer without inf/nan filter
|
||||
args = TrainingArguments(
|
||||
self.get_auto_remove_tmp_dir(),
|
||||
num_train_epochs=num_train_epochs,
|
||||
learning_rate=learning_rate,
|
||||
warmup_steps=num_warmup_steps,
|
||||
lr_scheduler_type="cosine",
|
||||
logging_steps=1,
|
||||
optim="apollo_adamw",
|
||||
optim_target_modules=[r".*attn.*", r".*mlp.*"],
|
||||
)
|
||||
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
|
||||
|
||||
# creating log history of trainer, results don't matter
|
||||
trainer.train()
|
||||
logs = trainer.state.log_history[1:][:-1]
|
||||
|
||||
# reach given learning rate peak and end with 0 lr
|
||||
self.assertTrue(logs[num_warmup_steps - 2]["learning_rate"] == learning_rate)
|
||||
self.assertTrue(logs[-1]["learning_rate"] == 0)
|
||||
|
||||
# increasing and decreasing pattern of lrs
|
||||
increasing_lrs = [
|
||||
logs[i]["learning_rate"] < logs[i + 1]["learning_rate"]
|
||||
for i in range(len(logs))
|
||||
if i < num_warmup_steps - 2
|
||||
]
|
||||
decreasing_lrs = [
|
||||
logs[i]["learning_rate"] > logs[i + 1]["learning_rate"]
|
||||
for i in range(len(logs) - 1)
|
||||
if i >= num_warmup_steps - 2
|
||||
]
|
||||
|
||||
self.assertTrue(all(increasing_lrs))
|
||||
self.assertTrue(all(decreasing_lrs))
|
||||
|
||||
# warm up steps << total steps
|
||||
self.assertTrue(len(decreasing_lrs) > len(increasing_lrs))
|
||||
|
||||
@require_torch_multi_accelerator
|
||||
def test_data_is_not_parallelized_when_model_is_parallel(self):
|
||||
model = RegressionModel()
|
||||
|
||||
Reference in New Issue
Block a user