Integrate Liger (Linkedin GPU Efficient Runtime) Kernel to Trainer (#32860)
* add liger integration * fix syntax * fix import issue * add trainer.md * Use _apply_liger_kernel() * Fixed log message * Update docs/source/en/trainer.md Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * Update docs/source/en/trainer.md Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * Update src/transformers/training_args.py Co-authored-by: Byron Hsu <byronhsu1230@gmail.com> * Update src/transformers/trainer.py Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * Update src/transformers/training_args.py Co-authored-by: Byron Hsu <byronhsu1230@gmail.com> * Update docs/source/en/trainer.md Co-authored-by: Byron Hsu <byronhsu1230@gmail.com> * Fixed checkstyle and updated readme * Added test * Fixed checkstyle * fix docstring * rename use_liger to use_liger_kernel * Trigger Build * Added test * add fix-copies * Fixed copy inconsistencies --------- Co-authored-by: shimizust <sshimizu@linkedin.com> Co-authored-by: Steven Shimizu <shimizust@gmail.com> Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>
This commit is contained in:
@@ -64,6 +64,7 @@ from transformers.testing_utils import (
|
||||
require_galore_torch,
|
||||
require_grokadamw,
|
||||
require_intel_extension_for_pytorch,
|
||||
require_liger_kernel,
|
||||
require_lomo,
|
||||
require_optuna,
|
||||
require_peft,
|
||||
@@ -1325,6 +1326,42 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.assertEqual(first_dataloader, first_dataloader_repeated)
|
||||
self.assertEqual(second_dataloader, second_dataloader_repeated)
|
||||
|
||||
@require_liger_kernel
|
||||
def test_use_liger_kernel_patching(self):
|
||||
# Test that the model code actually gets patched with Liger kernel
|
||||
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
||||
|
||||
from transformers.models.llama import modeling_llama
|
||||
|
||||
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||
tiny_llama = LlamaForCausalLM(config)
|
||||
|
||||
args = TrainingArguments(
|
||||
"./test",
|
||||
use_liger_kernel=True,
|
||||
)
|
||||
Trainer(tiny_llama, args)
|
||||
|
||||
# Check that one of the Llama model layers has been correctly patched with Liger kernel
|
||||
self.assertEqual(modeling_llama.LlamaRMSNorm, LigerRMSNorm)
|
||||
|
||||
@require_liger_kernel
|
||||
@require_torch_gpu
|
||||
def test_use_liger_kernel_trainer(self):
|
||||
# Check that trainer still works with liger kernel applied
|
||||
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||
tiny_llama = LlamaForCausalLM(config)
|
||||
|
||||
x = torch.randint(0, 100, (128,))
|
||||
train_dataset = RepeatDataset(x)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
args = TrainingArguments(tmpdir, learning_rate=1e-2, logging_steps=5, max_steps=20, use_liger_kernel=True)
|
||||
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
|
||||
|
||||
# Check this works
|
||||
_ = trainer.train()
|
||||
|
||||
@require_lomo
|
||||
@require_torch_gpu
|
||||
def test_lomo(self):
|
||||
|
||||
Reference in New Issue
Block a user