Optim: APOLLO optimizer integration (#36062)
* Added APOLLO optimizer integration * fix comment * Remove redundancy: Modularize low-rank optimizer construction * Remove redundancy: Remove useless comment * Fix comment: Add typing * Fix comment: Rewrite apollo desc
This commit is contained in:
@@ -66,6 +66,7 @@ from transformers.testing_utils import (
|
||||
get_tests_dir,
|
||||
is_staging_test,
|
||||
require_accelerate,
|
||||
require_apollo_torch,
|
||||
require_bitsandbytes,
|
||||
require_deepspeed,
|
||||
require_galore_torch,
|
||||
@@ -2259,6 +2260,168 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
# warm up steps << total steps
|
||||
self.assertTrue(len(decreasing_lrs) > len(increasing_lrs))
|
||||
|
||||
@require_apollo_torch
|
||||
@require_torch_gpu
|
||||
def test_apollo(self):
|
||||
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||
tiny_llama = LlamaForCausalLM(config)
|
||||
x = torch.randint(0, 100, (128,))
|
||||
train_dataset = RepeatDataset(x)
|
||||
|
||||
# Trainer without inf/nan filter
|
||||
args = TrainingArguments(
|
||||
self.get_auto_remove_tmp_dir(),
|
||||
learning_rate=1e-9,
|
||||
logging_steps=5,
|
||||
optim="apollo_adamw",
|
||||
optim_target_modules=[r".*attn.*", r".*mlp.*"],
|
||||
)
|
||||
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
|
||||
|
||||
# Check this works
|
||||
_ = trainer.train()
|
||||
|
||||
@require_apollo_torch
|
||||
@require_torch_gpu
|
||||
def test_apollo_extra_args(self):
|
||||
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||
tiny_llama = LlamaForCausalLM(config)
|
||||
x = torch.randint(0, 100, (128,))
|
||||
train_dataset = RepeatDataset(x)
|
||||
|
||||
# Trainer without inf/nan filter
|
||||
args = TrainingArguments(
|
||||
self.get_auto_remove_tmp_dir(),
|
||||
learning_rate=1e-9,
|
||||
logging_steps=5,
|
||||
optim="apollo_adamw",
|
||||
optim_args="proj=random,scale_type=tensor,rank=1,update_proj_gap=100,scale=128.0",
|
||||
optim_target_modules=[r".*attn.*", r".*mlp.*"],
|
||||
)
|
||||
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
|
||||
|
||||
# Check this works
|
||||
_ = trainer.train()
|
||||
|
||||
@require_apollo_torch
|
||||
@require_torch_gpu
|
||||
def test_apollo_layerwise(self):
|
||||
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||
tiny_llama = LlamaForCausalLM(config)
|
||||
x = torch.randint(0, 100, (128,))
|
||||
train_dataset = RepeatDataset(x)
|
||||
|
||||
# Trainer without inf/nan filter
|
||||
args = TrainingArguments(
|
||||
self.get_auto_remove_tmp_dir(),
|
||||
learning_rate=1e-9,
|
||||
logging_steps=5,
|
||||
optim="apollo_adamw_layerwise",
|
||||
optim_target_modules=[r".*attn.*", r".*mlp.*"],
|
||||
)
|
||||
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
|
||||
|
||||
# Check this works
|
||||
_ = trainer.train()
|
||||
|
||||
@require_apollo_torch
|
||||
@require_torch_gpu
|
||||
def test_apollo_layerwise_with_scheduler(self):
|
||||
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||
tiny_llama = LlamaForCausalLM(config)
|
||||
x = torch.randint(0, 100, (128,))
|
||||
train_dataset = RepeatDataset(x)
|
||||
|
||||
# Trainer without inf/nan filter
|
||||
args = TrainingArguments(
|
||||
self.get_auto_remove_tmp_dir(),
|
||||
learning_rate=1e-9,
|
||||
logging_steps=5,
|
||||
optim="apollo_adamw_layerwise",
|
||||
lr_scheduler_type="cosine",
|
||||
optim_target_modules=[r".*attn.*", r".*mlp.*"],
|
||||
)
|
||||
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
|
||||
|
||||
# Check this works
|
||||
_ = trainer.train()
|
||||
|
||||
@require_apollo_torch
|
||||
@require_torch_gpu
|
||||
def test_apollo_lr_display_without_scheduler(self):
|
||||
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||
tiny_llama = LlamaForCausalLM(config)
|
||||
x = torch.randint(0, 100, (128,))
|
||||
train_dataset = RepeatDataset(x)
|
||||
|
||||
learning_rate = 1e-9
|
||||
num_steps = 10
|
||||
|
||||
# Trainer without inf/nan filter
|
||||
args = TrainingArguments(
|
||||
self.get_auto_remove_tmp_dir(),
|
||||
learning_rate=learning_rate,
|
||||
logging_steps=5,
|
||||
optim="apollo_adamw",
|
||||
optim_target_modules=[r".*attn.*", r".*mlp.*"],
|
||||
)
|
||||
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
|
||||
trainer.create_optimizer_and_scheduler(num_training_steps=num_steps)
|
||||
|
||||
# reflects displayed lr in trainer
|
||||
self.assertEqual(trainer.get_learning_rates(), [learning_rate, learning_rate])
|
||||
|
||||
@require_apollo_torch
|
||||
@require_torch_gpu
|
||||
def test_apollo_lr_display_with_scheduler(self):
|
||||
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||
tiny_llama = LlamaForCausalLM(config)
|
||||
x = torch.randint(0, 100, (128,))
|
||||
train_dataset = RepeatDataset(x)
|
||||
|
||||
learning_rate = 2e-4
|
||||
num_train_epochs = 10
|
||||
num_warmup_steps = 5
|
||||
|
||||
# Trainer without inf/nan filter
|
||||
args = TrainingArguments(
|
||||
self.get_auto_remove_tmp_dir(),
|
||||
num_train_epochs=num_train_epochs,
|
||||
learning_rate=learning_rate,
|
||||
warmup_steps=num_warmup_steps,
|
||||
lr_scheduler_type="cosine",
|
||||
logging_steps=1,
|
||||
optim="apollo_adamw",
|
||||
optim_target_modules=[r".*attn.*", r".*mlp.*"],
|
||||
)
|
||||
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
|
||||
|
||||
# creating log history of trainer, results don't matter
|
||||
trainer.train()
|
||||
logs = trainer.state.log_history[1:][:-1]
|
||||
|
||||
# reach given learning rate peak and end with 0 lr
|
||||
self.assertTrue(logs[num_warmup_steps - 2]["learning_rate"] == learning_rate)
|
||||
self.assertTrue(logs[-1]["learning_rate"] == 0)
|
||||
|
||||
# increasing and decreasing pattern of lrs
|
||||
increasing_lrs = [
|
||||
logs[i]["learning_rate"] < logs[i + 1]["learning_rate"]
|
||||
for i in range(len(logs))
|
||||
if i < num_warmup_steps - 2
|
||||
]
|
||||
decreasing_lrs = [
|
||||
logs[i]["learning_rate"] > logs[i + 1]["learning_rate"]
|
||||
for i in range(len(logs) - 1)
|
||||
if i >= num_warmup_steps - 2
|
||||
]
|
||||
|
||||
self.assertTrue(all(increasing_lrs))
|
||||
self.assertTrue(all(decreasing_lrs))
|
||||
|
||||
# warm up steps << total steps
|
||||
self.assertTrue(len(decreasing_lrs) > len(increasing_lrs))
|
||||
|
||||
@require_torch_multi_accelerator
|
||||
def test_data_is_not_parallelized_when_model_is_parallel(self):
|
||||
model = RegressionModel()
|
||||
|
||||
Reference in New Issue
Block a user