Fix galore lr display with schedulers (#31710)
* fix galore lr display with lr schedulers
* style
* add some tests to check for displayed lrs
* copy-paste err for warmup steps
* standardize the default lr to be only in the optimizer
* trying out my luck with the reads
This commit is contained in:
@@ -519,7 +519,7 @@ def get_scheduler(
|
|||||||
if param.requires_grad:
|
if param.requires_grad:
|
||||||
param.register_post_accumulate_grad_hook(scheduler_hook)
|
param.register_post_accumulate_grad_hook(scheduler_hook)
|
||||||
|
|
||||||
return LayerWiseDummyScheduler()
|
return LayerWiseDummyScheduler(optimizer_dict=optimizer_dict, lr=optimizer.defaults["lr"])
|
||||||
|
|
||||||
if name == SchedulerType.CONSTANT:
|
if name == SchedulerType.CONSTANT:
|
||||||
return schedule_func(optimizer)
|
return schedule_func(optimizer)
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ import warnings
|
|||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
from itertools import chain
|
||||||
from logging import StreamHandler
|
from logging import StreamHandler
|
||||||
from typing import Any, Dict, Iterator, List, Optional, Union
|
from typing import Any, Dict, Iterator, List, Optional, Union
|
||||||
|
|
||||||
@@ -1379,13 +1380,24 @@ class LayerWiseDummyScheduler(LRScheduler):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
optimizer = LayerWiseDummyOptimizer()
|
self.default_lr = kwargs["lr"]
|
||||||
|
optimizer = LayerWiseDummyOptimizer(**kwargs)
|
||||||
last_epoch = -1
|
last_epoch = -1
|
||||||
verbose = False
|
verbose = False
|
||||||
super().__init__(optimizer, last_epoch, verbose)
|
super().__init__(optimizer, last_epoch, verbose)
|
||||||
|
|
||||||
def get_lr(self):
|
def get_lr(self):
|
||||||
return [group["lr"] for group in self.optimizer.param_groups]
|
# default value
|
||||||
|
lrs = [self.default_lr]
|
||||||
|
|
||||||
|
# we take each lr in the parameters if they exist, assumes the optimizer to be the `LayerWiseDummyOptimizer`
|
||||||
|
if self.optimizer is not None:
|
||||||
|
param_wise_lrs = [
|
||||||
|
[group["lr"] for group in optim.param_groups] for optim in self.optimizer.optimizer_dict.values()
|
||||||
|
]
|
||||||
|
lrs = list(chain(*param_wise_lrs))
|
||||||
|
|
||||||
|
return lrs
|
||||||
|
|
||||||
def _get_closed_form_lr(self):
|
def _get_closed_form_lr(self):
|
||||||
return self.base_lrs
|
return self.base_lrs
|
||||||
|
|||||||
@@ -1653,6 +1653,84 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
self.assertTrue(galore_peak_memory < upper_bound_pm)
|
self.assertTrue(galore_peak_memory < upper_bound_pm)
|
||||||
self.assertTrue(lower_bound_pm < galore_peak_memory)
|
self.assertTrue(lower_bound_pm < galore_peak_memory)
|
||||||
|
|
||||||
|
@require_galore_torch
|
||||||
|
@require_torch_gpu
|
||||||
|
def test_galore_lr_display_without_scheduler(self):
|
||||||
|
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||||
|
tiny_llama = LlamaForCausalLM(config)
|
||||||
|
x = torch.randint(0, 100, (128,))
|
||||||
|
train_dataset = RepeatDataset(x)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
learning_rate = 1e-9
|
||||||
|
num_steps = 10
|
||||||
|
|
||||||
|
# Trainer without inf/nan filter
|
||||||
|
args = TrainingArguments(
|
||||||
|
tmpdir,
|
||||||
|
learning_rate=learning_rate,
|
||||||
|
logging_steps=5,
|
||||||
|
optim="galore_adamw",
|
||||||
|
optim_target_modules=[r".*attn.*", r".*mlp.*"],
|
||||||
|
)
|
||||||
|
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
|
||||||
|
trainer.create_optimizer_and_scheduler(num_training_steps=num_steps)
|
||||||
|
|
||||||
|
# reflects displayed lr in trainer
|
||||||
|
self.assertEqual(trainer.get_learning_rates(), [learning_rate, learning_rate])
|
||||||
|
|
||||||
|
@require_galore_torch
|
||||||
|
@require_torch_gpu
|
||||||
|
def test_galore_lr_display_with_scheduler(self):
|
||||||
|
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||||
|
tiny_llama = LlamaForCausalLM(config)
|
||||||
|
x = torch.randint(0, 100, (128,))
|
||||||
|
train_dataset = RepeatDataset(x)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
learning_rate = 2e-4
|
||||||
|
num_train_epochs = 2
|
||||||
|
num_warmup_steps = 5
|
||||||
|
|
||||||
|
# Trainer without inf/nan filter
|
||||||
|
args = TrainingArguments(
|
||||||
|
tmpdir,
|
||||||
|
num_train_epochs=num_train_epochs,
|
||||||
|
learning_rate=learning_rate,
|
||||||
|
warmup_steps=num_warmup_steps,
|
||||||
|
lr_scheduler_type="cosine",
|
||||||
|
logging_steps=1,
|
||||||
|
optim="galore_adamw",
|
||||||
|
optim_target_modules=[r".*attn.*", r".*mlp.*"],
|
||||||
|
)
|
||||||
|
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
|
||||||
|
|
||||||
|
# creating log history of trainer, results don't matter
|
||||||
|
trainer.train()
|
||||||
|
logs = trainer.state.log_history[1:][:-1]
|
||||||
|
|
||||||
|
# reach given learning rate peak and end with 0 lr
|
||||||
|
self.assertTrue(logs[num_warmup_steps - 2]["learning_rate"] == learning_rate)
|
||||||
|
self.assertTrue(logs[-1]["learning_rate"] == 0)
|
||||||
|
|
||||||
|
# increasing and decreasing pattern of lrs
|
||||||
|
increasing_lrs = [
|
||||||
|
logs[i]["learning_rate"] < logs[i + 1]["learning_rate"]
|
||||||
|
for i in range(len(logs))
|
||||||
|
if i < num_warmup_steps - 2
|
||||||
|
]
|
||||||
|
decreasing_lrs = [
|
||||||
|
logs[i]["learning_rate"] > logs[i + 1]["learning_rate"]
|
||||||
|
for i in range(len(logs) - 1)
|
||||||
|
if i >= num_warmup_steps - 2
|
||||||
|
]
|
||||||
|
|
||||||
|
self.assertTrue(all(increasing_lrs))
|
||||||
|
self.assertTrue(all(decreasing_lrs))
|
||||||
|
|
||||||
|
# warm up steps << total steps
|
||||||
|
self.assertTrue(len(decreasing_lrs) > len(increasing_lrs))
|
||||||
|
|
||||||
@require_torch_multi_accelerator
|
@require_torch_multi_accelerator
|
||||||
def test_data_is_not_parallelized_when_model_is_parallel(self):
|
def test_data_is_not_parallelized_when_model_is_parallel(self):
|
||||||
model = RegressionModel()
|
model = RegressionModel()
|
||||||
|
|||||||
Reference in New Issue
Block a user