FEAT / Optim: Add GaLore optimizer (#29588)

* add galore v1

* add import

* add tests and doc

* fix doctest

* forward contrib credits from discussions

* forward contrib credits from discussions

* Apply suggestions from code review

Co-authored-by: Zach Mueller <muellerzr@gmail.com>

* fix failing tests'

* switch to `optim_target_modules` and clarify docs

* more clarification

* enhance lookup logic

* update a test to add peak memory

* add regex, all-linear and single string support

* add layer-wise optimization through DummyOptimizers and LRSchedulers

* forward contrib credits from discussions and original idea

* add a section about DDP not supported in layerwise

* Update src/transformers/trainer.py

Co-authored-by: Zach Mueller <muellerzr@gmail.com>

* fix self

* check only if layer_wise

* Update src/transformers/training_args.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* oops

* make use of intervals

* clarify comment

* add matching tests

* GaLoRe -> GaLore

* move to `get_scheduler`

* add note on docs

* add a warning

* adapt a bit the docs

* update docstring

* support original API

* Update docs/source/en/trainer.md

* slightly refactor

* Update docs/source/en/trainer.md

Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>

* Update src/transformers/training_args.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* fix args parsing and add tests

* remove warning for regex

* fix type hint

* add note about extra args

* make `is_regex` return optional

---------

Co-authored-by: Maxime <maximegmd @users.noreply.github.com>
Co-authored-by: Wing Lian <winglian @users.noreply.github.com>
Co-authored-by: Zach Mueller <muellerzr@gmail.com>
Co-authored-by: hiyouga <hiyouga@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
This commit is contained in:
Younes Belkada
2024-03-19 11:40:23 +01:00
committed by GitHub
parent 484e10f7f2
commit f6261d7d81
10 changed files with 734 additions and 3 deletions

View File

@@ -60,6 +60,7 @@ from transformers.testing_utils import (
require_accelerate,
require_bitsandbytes,
require_deepspeed,
require_galore_torch,
require_intel_extension_for_pytorch,
require_optuna,
require_peft,
@@ -84,7 +85,7 @@ from transformers.testing_utils import (
slow,
torch_device,
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, HPSearchBackend
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, HPSearchBackend, check_target_module_exists
from transformers.training_args import OptimizerNames
from transformers.utils import (
SAFE_WEIGHTS_INDEX_NAME,
@@ -114,6 +115,8 @@ if is_torch_available():
GPT2Config,
GPT2LMHeadModel,
LineByLineTextDataset,
LlamaConfig,
LlamaForCausalLM,
PreTrainedModel,
Trainer,
TrainerState,
@@ -146,6 +149,31 @@ class RegressionDataset:
return result
# Converting Bytes to Megabytes
def bytes2megabytes(x):
return int(x / 2**20)
# Copied from acclerate: https://github.com/huggingface/accelerate/blob/ee163b66fb7848892519e804688cb4ae981aacbe/src/accelerate/test_utils/scripts/external_deps/test_peak_memory_usage.py#L40C1-L73C68
class TorchTracemalloc:
def __enter__(self):
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated() # reset the peak gauge to zero
self.begin = torch.cuda.memory_allocated()
return self
def __exit__(self, *exc):
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
self.end = torch.cuda.memory_allocated()
self.peak = torch.cuda.max_memory_allocated()
self.used = bytes2megabytes(self.end - self.begin)
self.peaked = bytes2megabytes(self.peak - self.begin)
@dataclasses.dataclass
class RegressionTrainingArguments(TrainingArguments):
a: float = 0.0
@@ -1069,6 +1097,293 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
trainer.train()
trainer.evaluate()
def test_galore_matched_modules(self):
regex_patterns = [r".*.attn.*", r".*.mlp.*"]
module_names = [
"model.transformer.h.0.ln_1",
"model.transformer.h.0.attn.q_proj",
"model.lm_head",
"model.transformer.h.0.mlp.up_proj",
]
expected_values = [False, True, False, True]
for expected_value, module_name in zip(expected_values, module_names):
is_module_matched, is_regex = check_target_module_exists(regex_patterns, module_name, return_is_regex=True)
self.assertTrue(is_module_matched == expected_value)
if is_module_matched:
self.assertTrue(is_regex)
exact_patterns = ["q_proj", "up_proj"]
module_names = [
"model.transformer.h.0.ln_1",
"model.transformer.h.0.attn.q_proj",
"model.lm_head",
"model.transformer.h.0.mlp.up_proj",
]
expected_values = [False, True, False, True]
for expected_value, module_name in zip(expected_values, module_names):
is_module_matched, is_regex = check_target_module_exists(exact_patterns, module_name, return_is_regex=True)
self.assertTrue(is_module_matched == expected_value)
if is_module_matched:
self.assertFalse(is_regex)
simple_regex = r".*.attn.*"
module_names = [
"model.transformer.h.0.ln_1",
"model.transformer.h.0.attn.q_proj",
"model.lm_head",
"model.transformer.h.0.mlp.up_proj",
]
expected_values = [False, True, False, False]
for expected_value, module_name in zip(expected_values, module_names):
is_module_matched, is_regex = check_target_module_exists(simple_regex, module_name, return_is_regex=True)
self.assertTrue(is_module_matched == expected_value)
if is_module_matched:
self.assertTrue(is_regex)
simple_regex = "model.transformer.h.0.attn.q_proj"
module_names = [
"model.transformer.h.0.ln_1",
"model.transformer.h.0.attn.q_proj",
"model.lm_head",
"model.transformer.h.0.mlp.up_proj",
]
expected_values = [False, True, False, False]
for expected_value, module_name in zip(expected_values, module_names):
is_module_matched, is_regex = check_target_module_exists(simple_regex, module_name, return_is_regex=True)
self.assertTrue(is_module_matched == expected_value)
if is_module_matched:
self.assertFalse(is_regex)
target_modules = ["attn", "mlp"]
module_names = [
"model.transformer.h.0.ln_1",
"model.transformer.h.0.attn.q_proj",
"model.lm_head",
"model.transformer.h.0.mlp.up_proj",
]
expected_values = [False, True, False, True]
for expected_value, module_name in zip(expected_values, module_names):
is_module_matched, is_regex = check_target_module_exists(target_modules, module_name, return_is_regex=True)
self.assertTrue(is_module_matched == expected_value)
if is_module_matched:
self.assertFalse(is_regex)
@require_galore_torch
@require_torch_gpu
def test_galore(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)
with tempfile.TemporaryDirectory() as tmpdir:
# Trainer without inf/nan filter
args = TrainingArguments(
tmpdir,
learning_rate=1e-9,
logging_steps=5,
optim="galore_adamw",
optim_target_modules=[r".*attn.*", r".*mlp.*"],
)
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
# Check this works
_ = trainer.train()
@require_galore_torch
@require_torch_gpu
def test_galore_extra_args(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)
with tempfile.TemporaryDirectory() as tmpdir:
# Trainer without inf/nan filter
args = TrainingArguments(
tmpdir,
learning_rate=1e-9,
logging_steps=5,
optim="galore_adamw",
optim_args="rank=64, update_proj_gap=100, scale=0.10",
optim_target_modules=[r".*attn.*", r".*mlp.*"],
)
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
# Check this works
_ = trainer.train()
@require_galore_torch
@require_torch_gpu
def test_galore_layerwise(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)
with tempfile.TemporaryDirectory() as tmpdir:
# Trainer without inf/nan filter
args = TrainingArguments(
tmpdir,
learning_rate=1e-9,
logging_steps=5,
optim="galore_adamw_layerwise",
optim_target_modules=[r".*attn.*", r".*mlp.*"],
)
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
# Check this works
_ = trainer.train()
@require_galore_torch
@require_torch_gpu
def test_galore_layerwise_with_scheduler(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)
with tempfile.TemporaryDirectory() as tmpdir:
# Trainer without inf/nan filter
args = TrainingArguments(
tmpdir,
learning_rate=1e-9,
logging_steps=5,
optim="galore_adamw_layerwise",
lr_scheduler_type="cosine",
optim_target_modules=[r".*attn.*", r".*mlp.*"],
)
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
# Check this works
_ = trainer.train()
@require_galore_torch
@require_torch_gpu
def test_galore_adamw_8bit(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)
with tempfile.TemporaryDirectory() as tmpdir:
# Trainer without inf/nan filter
args = TrainingArguments(
tmpdir,
learning_rate=1e-9,
logging_steps=5,
optim="galore_adamw_8bit",
optim_target_modules=[r".*attn.*", r".*mlp.*"],
)
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
# Check this works
_ = trainer.train()
@require_galore_torch
@require_torch_gpu
def test_galore_adafactor(self):
# These are the intervals of the peak memory usage of training such a tiny model
# if the peak memory goes outside that range, then we know there might be a bug somewhere
upper_bound_pm = 700
lower_bound_pm = 650
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)
with tempfile.TemporaryDirectory() as tmpdir, TorchTracemalloc() as tracemalloc:
# Trainer without inf/nan filter
args = TrainingArguments(
tmpdir,
learning_rate=1e-9,
logging_steps=5,
optim="galore_adafactor",
optim_target_modules=[r".*attn.*", r".*mlp.*"],
)
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
# Check this works
_ = trainer.train()
galore_peak_memory = tracemalloc.peaked + bytes2megabytes(tracemalloc.begin)
self.assertTrue(galore_peak_memory < upper_bound_pm)
self.assertTrue(lower_bound_pm < galore_peak_memory)
@require_galore_torch
@require_torch_gpu
def test_galore_adafactor_attention_only(self):
# These are the intervals of the peak memory usage of training such a tiny model
# if the peak memory goes outside that range, then we know there might be a bug somewhere
upper_bound_pm = 700
lower_bound_pm = 650
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)
with tempfile.TemporaryDirectory() as tmpdir, TorchTracemalloc() as tracemalloc:
# Trainer without inf/nan filter
args = TrainingArguments(
tmpdir,
learning_rate=1e-9,
logging_steps=5,
optim="galore_adafactor",
optim_target_modules=["q_proj", "k_proj", "v_proj"],
)
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
# Check this works
_ = trainer.train()
galore_peak_memory = tracemalloc.peaked + bytes2megabytes(tracemalloc.begin)
self.assertTrue(galore_peak_memory < upper_bound_pm)
self.assertTrue(lower_bound_pm < galore_peak_memory)
@require_galore_torch
@require_torch_gpu
def test_galore_adafactor_all_linear(self):
# These are the intervals of the peak memory usage of training such a tiny model
# if the peak memory goes outside that range, then we know there might be a bug somewhere
upper_bound_pm = 700
lower_bound_pm = 650
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)
with tempfile.TemporaryDirectory() as tmpdir, TorchTracemalloc() as tracemalloc:
# Trainer without inf/nan filter
args = TrainingArguments(
tmpdir,
learning_rate=1e-9,
logging_steps=5,
optim="galore_adafactor",
optim_target_modules="all-linear",
)
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
# Check this works
_ = trainer.train()
galore_peak_memory = tracemalloc.peaked + bytes2megabytes(tracemalloc.begin)
self.assertTrue(galore_peak_memory < upper_bound_pm)
self.assertTrue(lower_bound_pm < galore_peak_memory)
@require_torch_multi_accelerator
def test_data_is_not_parallelized_when_model_is_parallel(self):
model = RegressionModel()