Integrate Liger (Linkedin GPU Efficient Runtime) Kernel to Trainer (#32860)

* add liger integration

* fix syntax

* fix import issue

* add trainer.md

* Use _apply_liger_kernel()

* Fixed log message

* Update docs/source/en/trainer.md

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update docs/source/en/trainer.md

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update src/transformers/training_args.py

Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>

* Update src/transformers/trainer.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update src/transformers/training_args.py

Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>

* Update docs/source/en/trainer.md

Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>

* Fixed checkstyle and updated readme

* Added test

* Fixed checkstyle

* fix docstring

* rename use_liger to use_liger_kernel

* Trigger Build

* Added test

* add fix-copies

* Fixed copy inconsistencies

---------

Co-authored-by: shimizust <sshimizu@linkedin.com>
Co-authored-by: Steven Shimizu <shimizust@gmail.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>
This commit is contained in:
Jason (Siyu) Zhu
2024-08-23 04:20:49 -07:00
committed by GitHub
parent 970a16ec7f
commit adb91179b9
7 changed files with 118 additions and 0 deletions

View File

@@ -64,6 +64,7 @@ from transformers.testing_utils import (
require_galore_torch,
require_grokadamw,
require_intel_extension_for_pytorch,
require_liger_kernel,
require_lomo,
require_optuna,
require_peft,
@@ -1325,6 +1326,42 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertEqual(first_dataloader, first_dataloader_repeated)
self.assertEqual(second_dataloader, second_dataloader_repeated)
@require_liger_kernel
def test_use_liger_kernel_patching(self):
# Test that the model code actually gets patched with Liger kernel
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from transformers.models.llama import modeling_llama
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
args = TrainingArguments(
"./test",
use_liger_kernel=True,
)
Trainer(tiny_llama, args)
# Check that one of the Llama model layers has been correctly patched with Liger kernel
self.assertEqual(modeling_llama.LlamaRMSNorm, LigerRMSNorm)
@require_liger_kernel
@require_torch_gpu
def test_use_liger_kernel_trainer(self):
# Check that trainer still works with liger kernel applied
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)
with tempfile.TemporaryDirectory() as tmpdir:
args = TrainingArguments(tmpdir, learning_rate=1e-2, logging_steps=5, max_steps=20, use_liger_kernel=True)
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
# Check this works
_ = trainer.train()
@require_lomo
@require_torch_gpu
def test_lomo(self):