FEAT / Optim: Add GaLore optimizer (#29588)
* add galore v1 * add import * add tests and doc * fix doctest * forward contrib credits from discussions * forward contrib credits from discussions * Apply suggestions from code review Co-authored-by: Zach Mueller <muellerzr@gmail.com> * fix failing tests' * switch to `optim_target_modules` and clarify docs * more clarification * enhance lookup logic * update a test to add peak memory * add regex, all-linear and single string support * add layer-wise optimization through DummyOptimizers and LRSchedulers * forward contrib credits from discussions and original idea * add a section about DDP not supported in layerwise * Update src/transformers/trainer.py Co-authored-by: Zach Mueller <muellerzr@gmail.com> * fix self * check only if layer_wise * Update src/transformers/training_args.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * oops * make use of intervals * clarify comment * add matching tests * GaLoRe -> GaLore * move to `get_scheduler` * add note on docs * add a warning * adapt a bit the docs * update docstring * support original API * Update docs/source/en/trainer.md * slightly refactor * Update docs/source/en/trainer.md Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> * Update src/transformers/training_args.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * fix args parsing and add tests * remove warning for regex * fix type hint * add note about extra args * make `is_regex` return optional --------- Co-authored-by: Maxime <maximegmd @users.noreply.github.com> Co-authored-by: Wing Lian <winglian @users.noreply.github.com> Co-authored-by: Zach Mueller <muellerzr@gmail.com> Co-authored-by: hiyouga <hiyouga@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
This commit is contained in:
@@ -60,6 +60,7 @@ from transformers.testing_utils import (
|
||||
require_accelerate,
|
||||
require_bitsandbytes,
|
||||
require_deepspeed,
|
||||
require_galore_torch,
|
||||
require_intel_extension_for_pytorch,
|
||||
require_optuna,
|
||||
require_peft,
|
||||
@@ -84,7 +85,7 @@ from transformers.testing_utils import (
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, HPSearchBackend
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, HPSearchBackend, check_target_module_exists
|
||||
from transformers.training_args import OptimizerNames
|
||||
from transformers.utils import (
|
||||
SAFE_WEIGHTS_INDEX_NAME,
|
||||
@@ -114,6 +115,8 @@ if is_torch_available():
|
||||
GPT2Config,
|
||||
GPT2LMHeadModel,
|
||||
LineByLineTextDataset,
|
||||
LlamaConfig,
|
||||
LlamaForCausalLM,
|
||||
PreTrainedModel,
|
||||
Trainer,
|
||||
TrainerState,
|
||||
@@ -146,6 +149,31 @@ class RegressionDataset:
|
||||
return result
|
||||
|
||||
|
||||
# Converting Bytes to Megabytes
|
||||
def bytes2megabytes(x):
    """Convert a byte count to a whole number of megabytes.

    Args:
        x: Number of bytes.

    Returns:
        int: ``x`` divided by 2**20, truncated toward zero.
    """
    return int(x / 2**20)
|
||||
|
||||
|
||||
# Copied from accelerate: https://github.com/huggingface/accelerate/blob/ee163b66fb7848892519e804688cb4ae981aacbe/src/accelerate/test_utils/scripts/external_deps/test_peak_memory_usage.py#L40C1-L73C68
|
||||
class TorchTracemalloc:
    """Context manager that records CUDA memory figures around a code block.

    Attributes set on the instance:
        begin: ``torch.cuda.memory_allocated()`` on entry (bytes).
        end: ``torch.cuda.memory_allocated()`` on exit (bytes).
        peak: ``torch.cuda.max_memory_allocated()`` on exit (bytes).
        used: ``end - begin`` converted with ``bytes2megabytes``.
        peaked: ``peak - begin`` converted with ``bytes2megabytes``.

    When CUDA is not available all figures are 0, so the attributes always
    exist after the ``with`` block.
    """

    def __enter__(self):
        gc.collect()
        # Default first so `begin` exists even when CUDA is unavailable.
        self.begin = 0
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            # Reset the peak gauge to zero. `reset_max_memory_allocated` is
            # deprecated and simply forwards to this call.
            torch.cuda.reset_peak_memory_stats()
            self.begin = torch.cuda.memory_allocated()
        return self

    def __exit__(self, *exc):
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            self.end = torch.cuda.memory_allocated()
            self.peak = torch.cuda.max_memory_allocated()
        else:
            # Nothing to measure without CUDA; keep the deltas below at 0.
            self.end = 0
            self.peak = 0
        self.used = bytes2megabytes(self.end - self.begin)
        self.peaked = bytes2megabytes(self.peak - self.begin)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class RegressionTrainingArguments(TrainingArguments):
|
||||
a: float = 0.0
|
||||
@@ -1069,6 +1097,293 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
trainer.train()
|
||||
trainer.evaluate()
|
||||
|
||||
def test_galore_matched_modules(self):
    """Check `check_target_module_exists` against several target specifications.

    Every case is run against the same four module names. For each case the
    test asserts the expected match result per module and, for modules that
    match, whether the match is reported as a regex match.
    """
    module_names = [
        "model.transformer.h.0.ln_1",
        "model.transformer.h.0.attn.q_proj",
        "model.lm_head",
        "model.transformer.h.0.mlp.up_proj",
    ]

    # (targets, expected match per module name, is_regex expected on a match)
    cases = [
        ([r".*.attn.*", r".*.mlp.*"], [False, True, False, True], True),
        (["q_proj", "up_proj"], [False, True, False, True], False),
        (r".*.attn.*", [False, True, False, False], True),
        ("model.transformer.h.0.attn.q_proj", [False, True, False, False], False),
        (["attn", "mlp"], [False, True, False, True], False),
    ]

    for targets, expected_values, expect_regex in cases:
        for expected_value, module_name in zip(expected_values, module_names):
            # subTest keeps the remaining combinations running after a failure
            # and names the failing (targets, module_name) pair in the report.
            with self.subTest(targets=targets, module_name=module_name):
                is_module_matched, is_regex = check_target_module_exists(
                    targets, module_name, return_is_regex=True
                )
                self.assertEqual(is_module_matched, expected_value)

                # `is_regex` is only checked for modules that matched.
                if is_module_matched:
                    if expect_regex:
                        self.assertTrue(is_regex)
                    else:
                        self.assertFalse(is_regex)
|
||||
|
||||
@require_galore_torch
@require_torch_gpu
def test_galore(self):
    """Smoke test: training a tiny Llama model with `optim="galore_adamw"` runs."""
    config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
    tiny_llama = LlamaForCausalLM(config)
    x = torch.randint(0, 100, (128,))
    train_dataset = RepeatDataset(x)

    with tempfile.TemporaryDirectory() as tmpdir:
        # GaLore AdamW, targeting modules whose names match the attn / mlp patterns
        args = TrainingArguments(
            tmpdir,
            learning_rate=1e-9,
            logging_steps=5,
            optim="galore_adamw",
            optim_target_modules=[r".*attn.*", r".*mlp.*"],
        )
        trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)

        # Check this works
        trainer.train()
|
||||
|
||||
@require_galore_torch
@require_torch_gpu
def test_galore_extra_args(self):
    """Smoke test: `galore_adamw` training runs when extra options are passed via `optim_args`."""
    config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
    tiny_llama = LlamaForCausalLM(config)
    x = torch.randint(0, 100, (128,))
    train_dataset = RepeatDataset(x)

    with tempfile.TemporaryDirectory() as tmpdir:
        # Same setup as the plain GaLore AdamW run, plus an `optim_args` string
        args = TrainingArguments(
            tmpdir,
            learning_rate=1e-9,
            logging_steps=5,
            optim="galore_adamw",
            optim_args="rank=64, update_proj_gap=100, scale=0.10",
            optim_target_modules=[r".*attn.*", r".*mlp.*"],
        )
        trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)

        # Check this works
        trainer.train()
|
||||
|
||||
@require_galore_torch
@require_torch_gpu
def test_galore_layerwise(self):
    """Smoke test: training a tiny Llama model with `optim="galore_adamw_layerwise"` runs."""
    config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
    tiny_llama = LlamaForCausalLM(config)
    x = torch.randint(0, 100, (128,))
    train_dataset = RepeatDataset(x)

    with tempfile.TemporaryDirectory() as tmpdir:
        # Layer-wise GaLore AdamW, targeting modules matching the attn / mlp patterns
        args = TrainingArguments(
            tmpdir,
            learning_rate=1e-9,
            logging_steps=5,
            optim="galore_adamw_layerwise",
            optim_target_modules=[r".*attn.*", r".*mlp.*"],
        )
        trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)

        # Check this works
        trainer.train()
|
||||
|
||||
@require_galore_torch
@require_torch_gpu
def test_galore_layerwise_with_scheduler(self):
    """Smoke test: `galore_adamw_layerwise` training runs with `lr_scheduler_type="cosine"`."""
    config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
    tiny_llama = LlamaForCausalLM(config)
    x = torch.randint(0, 100, (128,))
    train_dataset = RepeatDataset(x)

    with tempfile.TemporaryDirectory() as tmpdir:
        # Layer-wise GaLore AdamW combined with an explicit cosine LR schedule
        args = TrainingArguments(
            tmpdir,
            learning_rate=1e-9,
            logging_steps=5,
            optim="galore_adamw_layerwise",
            lr_scheduler_type="cosine",
            optim_target_modules=[r".*attn.*", r".*mlp.*"],
        )
        trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)

        # Check this works
        trainer.train()
|
||||
|
||||
@require_galore_torch
@require_torch_gpu
def test_galore_adamw_8bit(self):
    """Smoke test: training a tiny Llama model with `optim="galore_adamw_8bit"` runs."""
    config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
    tiny_llama = LlamaForCausalLM(config)
    x = torch.randint(0, 100, (128,))
    train_dataset = RepeatDataset(x)

    with tempfile.TemporaryDirectory() as tmpdir:
        # 8-bit GaLore AdamW, targeting modules matching the attn / mlp patterns
        args = TrainingArguments(
            tmpdir,
            learning_rate=1e-9,
            logging_steps=5,
            optim="galore_adamw_8bit",
            optim_target_modules=[r".*attn.*", r".*mlp.*"],
        )
        trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)

        # Check this works
        trainer.train()
|
||||
|
||||
@require_galore_torch
@require_torch_gpu
def test_galore_adafactor(self):
    """Train with `optim="galore_adafactor"` and check peak memory stays in the expected interval."""
    # These are the intervals of the peak memory usage of training such a tiny model
    # if the peak memory goes outside that range, then we know there might be a bug somewhere
    upper_bound_pm = 700
    lower_bound_pm = 650

    config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
    tiny_llama = LlamaForCausalLM(config)
    x = torch.randint(0, 100, (128,))
    train_dataset = RepeatDataset(x)

    with tempfile.TemporaryDirectory() as tmpdir, TorchTracemalloc() as tracemalloc:
        # GaLore Adafactor, targeting modules matching the attn / mlp patterns
        args = TrainingArguments(
            tmpdir,
            learning_rate=1e-9,
            logging_steps=5,
            optim="galore_adafactor",
            optim_target_modules=[r".*attn.*", r".*mlp.*"],
        )
        trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)

        # Check this works
        trainer.train()

    # `peaked` is only set when the tracker exits, so it is read after the `with` block.
    # It is relative to `begin`; adding `begin` back gives the absolute peak.
    galore_peak_memory = tracemalloc.peaked + bytes2megabytes(tracemalloc.begin)

    self.assertLess(galore_peak_memory, upper_bound_pm)
    self.assertLess(lower_bound_pm, galore_peak_memory)
|
||||
|
||||
@require_galore_torch
@require_torch_gpu
def test_galore_adafactor_attention_only(self):
    """Train with `galore_adafactor` on `q_proj`/`k_proj`/`v_proj` and check the peak memory interval."""
    # These are the intervals of the peak memory usage of training such a tiny model
    # if the peak memory goes outside that range, then we know there might be a bug somewhere
    upper_bound_pm = 700
    lower_bound_pm = 650

    config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
    tiny_llama = LlamaForCausalLM(config)
    x = torch.randint(0, 100, (128,))
    train_dataset = RepeatDataset(x)

    with tempfile.TemporaryDirectory() as tmpdir, TorchTracemalloc() as tracemalloc:
        # GaLore Adafactor with plain module names as targets instead of regex patterns
        args = TrainingArguments(
            tmpdir,
            learning_rate=1e-9,
            logging_steps=5,
            optim="galore_adafactor",
            optim_target_modules=["q_proj", "k_proj", "v_proj"],
        )
        trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)

        # Check this works
        trainer.train()

    # `peaked` is only set when the tracker exits, so it is read after the `with` block.
    # It is relative to `begin`; adding `begin` back gives the absolute peak.
    galore_peak_memory = tracemalloc.peaked + bytes2megabytes(tracemalloc.begin)

    self.assertLess(galore_peak_memory, upper_bound_pm)
    self.assertLess(lower_bound_pm, galore_peak_memory)
|
||||
|
||||
@require_galore_torch
@require_torch_gpu
def test_galore_adafactor_all_linear(self):
    """Train with `galore_adafactor` and `optim_target_modules="all-linear"`; check the peak memory interval."""
    # These are the intervals of the peak memory usage of training such a tiny model
    # if the peak memory goes outside that range, then we know there might be a bug somewhere
    upper_bound_pm = 700
    lower_bound_pm = 650

    config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
    tiny_llama = LlamaForCausalLM(config)
    x = torch.randint(0, 100, (128,))
    train_dataset = RepeatDataset(x)

    with tempfile.TemporaryDirectory() as tmpdir, TorchTracemalloc() as tracemalloc:
        # GaLore Adafactor with the single-string "all-linear" target specification
        args = TrainingArguments(
            tmpdir,
            learning_rate=1e-9,
            logging_steps=5,
            optim="galore_adafactor",
            optim_target_modules="all-linear",
        )
        trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)

        # Check this works
        trainer.train()

    # `peaked` is only set when the tracker exits, so it is read after the `with` block.
    # It is relative to `begin`; adding `begin` back gives the absolute peak.
    galore_peak_memory = tracemalloc.peaked + bytes2megabytes(tracemalloc.begin)

    self.assertLess(galore_peak_memory, upper_bound_pm)
    self.assertLess(lower_bound_pm, galore_peak_memory)
|
||||
|
||||
@require_torch_multi_accelerator
|
||||
def test_data_is_not_parallelized_when_model_is_parallel(self):
|
||||
model = RegressionModel()
|
||||
|
||||
Reference in New Issue
Block a user