FEAT [Trainer / bnb]: Add RMSProp from bitsandbytes to HF Trainer (#29082)
* add RMSProp to Trainer * revert some change * Update src/transformers/trainer.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -58,6 +58,7 @@ from transformers.testing_utils import (
|
||||
get_tests_dir,
|
||||
is_staging_test,
|
||||
require_accelerate,
|
||||
require_bitsandbytes,
|
||||
require_deepspeed,
|
||||
require_intel_extension_for_pytorch,
|
||||
require_optuna,
|
||||
@@ -872,6 +873,56 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
train_output = trainer.train()
|
||||
self.assertEqual(train_output.global_step, 10)
|
||||
|
||||
@require_bitsandbytes
|
||||
def test_rmsprop_bnb(self):
|
||||
config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
|
||||
tiny_gpt2 = GPT2LMHeadModel(config)
|
||||
x = torch.randint(0, 100, (128,))
|
||||
train_dataset = RepeatDataset(x)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Trainer without inf/nan filter
|
||||
args = TrainingArguments(
|
||||
tmpdir, learning_rate=1e-9, logging_steps=5, logging_nan_inf_filter=False, optim="rmsprop_bnb"
|
||||
)
|
||||
trainer = Trainer(tiny_gpt2, args, train_dataset=train_dataset)
|
||||
|
||||
# Check that it trains without errors
|
||||
trainer.train()
|
||||
|
||||
@require_bitsandbytes
|
||||
def test_rmsprop_bnb_8bit(self):
|
||||
config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
|
||||
tiny_gpt2 = GPT2LMHeadModel(config)
|
||||
x = torch.randint(0, 100, (128,))
|
||||
train_dataset = RepeatDataset(x)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Trainer without inf/nan filter
|
||||
args = TrainingArguments(
|
||||
tmpdir, learning_rate=1e-9, logging_steps=5, logging_nan_inf_filter=False, optim="rmsprop_bnb_8bit"
|
||||
)
|
||||
trainer = Trainer(tiny_gpt2, args, train_dataset=train_dataset)
|
||||
|
||||
# Check that it trains without errors
|
||||
trainer.train()
|
||||
|
||||
@require_bitsandbytes
|
||||
def test_rmsprop_bnb_32bit(self):
|
||||
config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
|
||||
tiny_gpt2 = GPT2LMHeadModel(config)
|
||||
x = torch.randint(0, 100, (128,))
|
||||
train_dataset = RepeatDataset(x)
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Trainer without inf/nan filter
|
||||
args = TrainingArguments(
|
||||
tmpdir, learning_rate=1e-9, logging_steps=5, logging_nan_inf_filter=False, optim="rmsprop_bnb_32bit"
|
||||
)
|
||||
trainer = Trainer(tiny_gpt2, args, train_dataset=train_dataset)
|
||||
|
||||
# Check that it trains without errors
|
||||
trainer.train()
|
||||
|
||||
def test_neftune(self):
|
||||
config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
|
||||
tiny_gpt2 = GPT2LMHeadModel(config)
|
||||
|
||||
Reference in New Issue
Block a user