From a01b033cb45bb5ad77a2aa676367ef764b92e038 Mon Sep 17 00:00:00 2001 From: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com> Date: Fri, 5 Jul 2024 19:59:09 +0200 Subject: [PATCH] Fix galore lr display with schedulers (#31710) * fix galore lr display with lr schedulers * style * add some tests to check for displayed lrs * copy-paste err for warmup steps * standardize the default lr to be only in the optimizer * trying out my luck with the reads --- src/transformers/optimization.py | 2 +- src/transformers/trainer_pt_utils.py | 16 +++++- tests/trainer/test_trainer.py | 78 ++++++++++++++++++++++++++++ 3 files changed, 93 insertions(+), 3 deletions(-) diff --git a/src/transformers/optimization.py b/src/transformers/optimization.py index a462e3d824..0ca5d36d0f 100644 --- a/src/transformers/optimization.py +++ b/src/transformers/optimization.py @@ -519,7 +519,7 @@ def get_scheduler( if param.requires_grad: param.register_post_accumulate_grad_hook(scheduler_hook) - return LayerWiseDummyScheduler() + return LayerWiseDummyScheduler(optimizer_dict=optimizer_dict, lr=optimizer.defaults["lr"]) if name == SchedulerType.CONSTANT: return schedule_func(optimizer) diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py index 856ba4f664..fcffcd3595 100644 --- a/src/transformers/trainer_pt_utils.py +++ b/src/transformers/trainer_pt_utils.py @@ -27,6 +27,7 @@ import warnings from collections.abc import Mapping from contextlib import contextmanager from dataclasses import dataclass, field +from itertools import chain from logging import StreamHandler from typing import Any, Dict, Iterator, List, Optional, Union @@ -1379,13 +1380,24 @@ class LayerWiseDummyScheduler(LRScheduler): """ def __init__(self, *args, **kwargs): - optimizer = LayerWiseDummyOptimizer() + self.default_lr = kwargs["lr"] + optimizer = LayerWiseDummyOptimizer(**kwargs) last_epoch = -1 verbose = False super().__init__(optimizer, last_epoch, verbose) def get_lr(self): - return [group["lr"] for group in self.optimizer.param_groups] + # default value + lrs = [self.default_lr] + + # we take each lr in the parameters if they exist, assumes the optimizer to be the `LayerWiseDummyOptimizer` + if self.optimizer is not None: + param_wise_lrs = [ + [group["lr"] for group in optim.param_groups] for optim in self.optimizer.optimizer_dict.values() + ] + lrs = list(chain(*param_wise_lrs)) + + return lrs def _get_closed_form_lr(self): return self.base_lrs diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 26fa462467..e31e6cb822 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1653,6 +1653,84 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): self.assertTrue(galore_peak_memory < upper_bound_pm) self.assertTrue(lower_bound_pm < galore_peak_memory) + @require_galore_torch + @require_torch_gpu + def test_galore_lr_display_without_scheduler(self): + config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4) + tiny_llama = LlamaForCausalLM(config) + x = torch.randint(0, 100, (128,)) + train_dataset = RepeatDataset(x) + + with tempfile.TemporaryDirectory() as tmpdir: + learning_rate = 1e-9 + num_steps = 10 + + # Trainer without inf/nan filter + args = TrainingArguments( + tmpdir, + learning_rate=learning_rate, + logging_steps=5, + optim="galore_adamw", + optim_target_modules=[r".*attn.*", r".*mlp.*"], + ) + trainer = Trainer(tiny_llama, args, train_dataset=train_dataset) + trainer.create_optimizer_and_scheduler(num_training_steps=num_steps) + + # reflects displayed lr in trainer + self.assertEqual(trainer.get_learning_rates(), [learning_rate, learning_rate]) + + @require_galore_torch + @require_torch_gpu + def test_galore_lr_display_with_scheduler(self): + config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4) + tiny_llama = LlamaForCausalLM(config) + x = torch.randint(0, 100, (128,)) + train_dataset = RepeatDataset(x) + + with tempfile.TemporaryDirectory() as tmpdir: + learning_rate = 2e-4 + num_train_epochs = 2 + num_warmup_steps = 5 + + # Trainer without inf/nan filter + args = TrainingArguments( + tmpdir, + num_train_epochs=num_train_epochs, + learning_rate=learning_rate, + warmup_steps=num_warmup_steps, + lr_scheduler_type="cosine", + logging_steps=1, + optim="galore_adamw", + optim_target_modules=[r".*attn.*", r".*mlp.*"], + ) + trainer = Trainer(tiny_llama, args, train_dataset=train_dataset) + + # creating log history of trainer, results don't matter + trainer.train() + logs = trainer.state.log_history[1:][:-1] + + # reach given learning rate peak and end with 0 lr + self.assertTrue(logs[num_warmup_steps - 2]["learning_rate"] == learning_rate) + self.assertTrue(logs[-1]["learning_rate"] == 0) + + # increasing and decreasing pattern of lrs + increasing_lrs = [ + logs[i]["learning_rate"] < logs[i + 1]["learning_rate"] + for i in range(len(logs)) + if i < num_warmup_steps - 2 + ] + decreasing_lrs = [ + logs[i]["learning_rate"] > logs[i + 1]["learning_rate"] + for i in range(len(logs) - 1) + if i >= num_warmup_steps - 2 + ] + + self.assertTrue(all(increasing_lrs)) + self.assertTrue(all(decreasing_lrs)) + + # warm up steps << total steps + self.assertTrue(len(decreasing_lrs) > len(increasing_lrs)) + @require_torch_multi_accelerator def test_data_is_not_parallelized_when_model_is_parallel(self): model = RegressionModel()