feat: add flexible Liger Kernel configuration to TrainingArguments (#38911)
* feat: add flexible Liger Kernel configuration to TrainingArguments
Add support for granular Liger Kernel configuration through a new
`liger_kernel_config` parameter in TrainingArguments. This allows users
to selectively enable/disable specific kernels (rope, swiglu, cross_entropy,
etc.) instead of the current approach, which relies on the default configuration.
Features:
- Add `liger_kernel_config` dict parameter to TrainingArguments
- Support selective kernel application for all supported models
- Maintain full backward compatibility with existing `use_liger_kernel` flag
Example usage:
```python
TrainingArguments(
use_liger_kernel=True,
liger_kernel_config={
"rope": True,
"swiglu": True,
"cross_entropy": False,
"fused_linear_cross_entropy": True
}
)
```
Closes #38905
* Address comments and update Liger section in Trainer docs
This commit is contained in:
committed by
GitHub
parent
89b35be618
commit
797860c68c
@@ -1792,6 +1792,25 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.assertEqual(modeling_llama.apply_rotary_pos_emb, liger_rotary_pos_emb)
|
||||
self.assertTrue(isinstance(tiny_llama.model.norm, LigerRMSNorm))
|
||||
|
||||
@require_liger_kernel
def test_use_liger_kernel_custom_config_patching(self):
    # Purpose: with `use_liger_kernel=True` and `liger_kernel_config={"rms_norm": False}`,
    # the model's final norm layer must not be a `LigerRMSNorm` after a Trainer is built.
    #
    # The llama modeling module is patched for the duration of the test so that any
    # monkey patching is cleaned up for subsequent tests.
    with patch("transformers.models.llama.modeling_llama"):
        from liger_kernel.transformers import LigerRMSNorm

        llama_config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
        model = LlamaForCausalLM(llama_config)

        # Don't apply Liger's RMSNorm
        kernel_overrides = {"rms_norm": False}
        training_args = TrainingArguments(
            self.get_auto_remove_tmp_dir(),
            use_liger_kernel=True,
            liger_kernel_config=kernel_overrides,
        )

        # The Trainer instance itself is not needed; only the effect that constructing
        # it has on `model` is inspected below.
        Trainer(model, training_args)

        # Check that the RMSNorm kernel is not applied as specified in the config
        norm_is_liger = isinstance(model.model.norm, LigerRMSNorm)
        self.assertFalse(norm_is_liger)
|
||||
|
||||
@require_liger_kernel
|
||||
@require_torch_accelerator
|
||||
def test_use_liger_kernel_trainer(self):
|
||||
@@ -1810,6 +1829,29 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
# Check this works
|
||||
_ = trainer.train()
|
||||
|
||||
@require_liger_kernel
@require_torch_accelerator
def test_use_liger_kernel_custom_config_trainer(self):
    # Smoke test: a short training run must complete when Liger kernels are enabled
    # together with a custom `liger_kernel_config`.
    llama_config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
    model = LlamaForCausalLM(llama_config)

    # 128 random token ids drawn from [0, 100), matching the vocab_size used above.
    token_ids = torch.randint(0, 100, (128,))
    dataset = RepeatDataset(token_ids)

    kernel_overrides = {"rms_norm": False, "cross_entropy": True, "fused_linear_cross_entropy": False}
    training_args = TrainingArguments(
        self.get_auto_remove_tmp_dir(),
        learning_rate=1e-2,
        logging_steps=5,
        max_steps=20,
        use_liger_kernel=True,
        liger_kernel_config=kernel_overrides,
    )
    trainer = Trainer(model, training_args, train_dataset=dataset)

    # Check this works; the training output is not inspected.
    trainer.train()
|
||||
|
||||
@require_lomo
|
||||
@require_torch_accelerator
|
||||
def test_lomo(self):
|
||||
|
||||
Reference in New Issue
Block a user